Repository: ml-explore/mlx Branch: main Commit: 70a0da6fca8a Files: 879 Total size: 6.3 MB Directory structure: gitextract_c0zbkz84/ ├── .clang-format ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ └── bug_report.md │ ├── actions/ │ │ ├── build-cuda-release/ │ │ │ └── action.yml │ │ ├── build-docs/ │ │ │ └── action.yml │ │ ├── build-linux/ │ │ │ └── action.yml │ │ ├── build-linux-release/ │ │ │ └── action.yml │ │ ├── build-macos/ │ │ │ └── action.yml │ │ ├── build-macos-release/ │ │ │ └── action.yml │ │ ├── build-windows/ │ │ │ └── action.yml │ │ ├── setup-linux/ │ │ │ └── action.yml │ │ ├── setup-macos/ │ │ │ └── action.yml │ │ ├── setup-windows/ │ │ │ └── action.yml │ │ ├── test-linux/ │ │ │ └── action.yml │ │ └── test-windows/ │ │ └── action.yml │ ├── dependabot.yml │ ├── pull_request_template.md │ ├── scripts/ │ │ ├── build-sanitizer-tests.sh │ │ └── setup+build-cpp-linux-fedora-container.sh │ └── workflows/ │ ├── build_and_test.yml │ ├── documentation.yml │ ├── nightly.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── ACKNOWLEDGMENTS.md ├── CITATION.cff ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── benchmarks/ │ ├── cpp/ │ │ ├── CMakeLists.txt │ │ ├── autograd.cpp │ │ ├── compare_devices.cpp │ │ ├── irregular_strides.cpp │ │ ├── single_ops.cpp │ │ └── time_utils.h │ ├── numpy/ │ │ ├── single_ops.py │ │ └── time_utils.py │ └── python/ │ ├── batch_matmul_bench.py │ ├── blas/ │ │ ├── bench_gemm.py │ │ └── bench_gemv.py │ ├── comparative/ │ │ ├── README.md │ │ ├── bench_mlx.py │ │ ├── bench_torch.py │ │ └── compare.py │ ├── compile_bench.py │ ├── conv1d_bench.py │ ├── conv2d_bench_cpu.py │ ├── conv2d_train_bench_cpu.py │ ├── conv2d_transpose_bench_cpu.py │ ├── conv3d_bench.py │ ├── conv3d_bench_cpu.py │ ├── conv3d_train_bench_cpu.py │ ├── conv3d_transpose_bench_cpu.py │ ├── conv_bench.py │ ├── conv_transpose_bench.py │ ├── conv_unaligned_bench.py │ ├── distributed_bench.py │ ├── 
einsum_bench.py │ ├── fft_bench.py │ ├── gather_bench.py │ ├── gather_mm_bench.py │ ├── gather_qmm_bench.py │ ├── hadamard_bench.py │ ├── large_gemm_bench.py │ ├── layer_norm_bench.py │ ├── masked_scatter.py │ ├── rms_norm_bench.py │ ├── rope_bench.py │ ├── scatter_bench.py │ ├── sdpa_bench.py │ ├── sdpa_vector_bench.py │ ├── segmented_mm_bench.py │ ├── single_ops.py │ ├── slice_update_bench.py │ ├── synchronize_bench.py │ └── time_utils.py ├── cmake/ │ ├── FindCUDNN.cmake │ ├── FindNCCL.cmake │ ├── Findnvpl.cmake │ └── extension.cmake ├── docs/ │ ├── .clang-format │ ├── .gitignore │ ├── .nojekyll │ ├── Doxyfile │ ├── Makefile │ ├── README.md │ ├── index.html │ ├── requirements.txt │ └── src/ │ ├── _templates/ │ │ ├── module-base-class.rst │ │ ├── nn-module-template.rst │ │ └── optimizers-template.rst │ ├── conf.py │ ├── cpp/ │ │ └── ops.rst │ ├── dev/ │ │ ├── custom_metal_kernels.rst │ │ ├── extensions.rst │ │ ├── metal_debugger.rst │ │ ├── metal_logging.rst │ │ └── mlx_in_cpp.rst │ ├── examples/ │ │ ├── data_parallelism.rst │ │ ├── linear_regression.rst │ │ ├── llama-inference.rst │ │ ├── mlp.rst │ │ └── tensor_parallelism.rst │ ├── index.rst │ ├── install.rst │ ├── python/ │ │ ├── array.rst │ │ ├── cuda.rst │ │ ├── data_types.rst │ │ ├── devices_and_streams.rst │ │ ├── distributed.rst │ │ ├── export.rst │ │ ├── fast.rst │ │ ├── fft.rst │ │ ├── linalg.rst │ │ ├── memory_management.rst │ │ ├── metal.rst │ │ ├── nn/ │ │ │ ├── distributed.rst │ │ │ ├── functions.rst │ │ │ ├── init.rst │ │ │ ├── layers.rst │ │ │ ├── losses.rst │ │ │ └── module.rst │ │ ├── nn.rst │ │ ├── ops.rst │ │ ├── optimizers/ │ │ │ ├── common_optimizers.rst │ │ │ ├── optimizer.rst │ │ │ └── schedulers.rst │ │ ├── optimizers.rst │ │ ├── random.rst │ │ ├── transforms.rst │ │ └── tree_utils.rst │ └── usage/ │ ├── compile.rst │ ├── distributed.rst │ ├── export.rst │ ├── function_transforms.rst │ ├── indexing.rst │ ├── launching_distributed.rst │ ├── lazy_evaluation.rst │ ├── numpy.rst │ ├── 
quick_start.rst │ ├── saving_and_loading.rst │ ├── unified_memory.rst │ └── using_streams.rst ├── examples/ │ ├── cmake_project/ │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ └── example.cpp │ ├── cpp/ │ │ ├── CMakeLists.txt │ │ ├── distributed.cpp │ │ ├── linear_regression.cpp │ │ ├── logistic_regression.cpp │ │ ├── metal_capture.cpp │ │ ├── timer.h │ │ └── tutorial.cpp │ ├── export/ │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ ├── eval_mlp.cpp │ │ ├── eval_mlp.py │ │ ├── train_mlp.cpp │ │ └── train_mlp.py │ ├── extensions/ │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ ├── axpby/ │ │ │ ├── axpby.cpp │ │ │ ├── axpby.h │ │ │ └── axpby.metal │ │ ├── bindings.cpp │ │ ├── mlx_sample_extensions/ │ │ │ └── __init__.py │ │ ├── pyproject.toml │ │ ├── requirements.txt │ │ ├── setup.py │ │ └── test.py │ └── python/ │ ├── linear_regression.py │ ├── logistic_regression.py │ └── qqmm.py ├── mlx/ │ ├── 3rdparty/ │ │ ├── .clang-format │ │ └── pocketfft.h │ ├── CMakeLists.txt │ ├── allocator.h │ ├── api.h │ ├── array.cpp │ ├── array.h │ ├── backend/ │ │ ├── common/ │ │ │ ├── CMakeLists.txt │ │ │ ├── binary.h │ │ │ ├── broadcasting.cpp │ │ │ ├── broadcasting.h │ │ │ ├── buffer_cache.h │ │ │ ├── common.cpp │ │ │ ├── compiled.cpp │ │ │ ├── compiled.h │ │ │ ├── copy.h │ │ │ ├── hadamard.h │ │ │ ├── load.cpp │ │ │ ├── matmul.h │ │ │ ├── quantized.h │ │ │ ├── reduce.cpp │ │ │ ├── reduce.h │ │ │ ├── slicing.cpp │ │ │ ├── slicing.h │ │ │ ├── ternary.h │ │ │ ├── unary.h │ │ │ ├── utils.cpp │ │ │ └── utils.h │ │ ├── cpu/ │ │ │ ├── CMakeLists.txt │ │ │ ├── arange.h │ │ │ ├── arg_reduce.cpp │ │ │ ├── binary.cpp │ │ │ ├── binary.h │ │ │ ├── binary_ops.h │ │ │ ├── binary_two.h │ │ │ ├── cholesky.cpp │ │ │ ├── compiled.cpp │ │ │ ├── compiled_preamble.h │ │ │ ├── conv.cpp │ │ │ ├── copy.cpp │ │ │ ├── copy.h │ │ │ ├── device_info.cpp │ │ │ ├── device_info.h │ │ │ ├── distributed.cpp │ │ │ ├── eig.cpp │ │ │ ├── eigh.cpp │ │ │ ├── encoder.cpp │ │ │ ├── encoder.h │ │ │ ├── eval.cpp │ │ │ ├── eval.h │ │ 
│ ├── fft.cpp │ │ │ ├── gemm.h │ │ │ ├── gemms/ │ │ │ │ ├── bnns.cpp │ │ │ │ ├── cblas.cpp │ │ │ │ ├── simd_bf16.cpp │ │ │ │ ├── simd_fp16.cpp │ │ │ │ └── simd_gemm.h │ │ │ ├── hadamard.cpp │ │ │ ├── indexing.cpp │ │ │ ├── inverse.cpp │ │ │ ├── jit_compiler.cpp │ │ │ ├── jit_compiler.h │ │ │ ├── lapack.h │ │ │ ├── logsumexp.cpp │ │ │ ├── luf.cpp │ │ │ ├── make_compiled_preamble.ps1 │ │ │ ├── make_compiled_preamble.sh │ │ │ ├── masked_mm.cpp │ │ │ ├── matmul.cpp │ │ │ ├── primitives.cpp │ │ │ ├── qrf.cpp │ │ │ ├── quantized.cpp │ │ │ ├── reduce.cpp │ │ │ ├── scan.cpp │ │ │ ├── select.cpp │ │ │ ├── simd/ │ │ │ │ ├── accelerate_fp16_simd.h │ │ │ │ ├── accelerate_simd.h │ │ │ │ ├── base_simd.h │ │ │ │ ├── math.h │ │ │ │ ├── neon_fp16_simd.h │ │ │ │ ├── simd.h │ │ │ │ └── type.h │ │ │ ├── slicing.h │ │ │ ├── softmax.cpp │ │ │ ├── sort.cpp │ │ │ ├── svd.cpp │ │ │ ├── ternary.h │ │ │ ├── threefry.cpp │ │ │ ├── threefry.h │ │ │ ├── unary.cpp │ │ │ ├── unary.h │ │ │ └── unary_ops.h │ │ ├── cuda/ │ │ │ ├── CMakeLists.txt │ │ │ ├── allocator.cpp │ │ │ ├── allocator.h │ │ │ ├── arange.cu │ │ │ ├── arg_reduce.cu │ │ │ ├── bin2h.cmake │ │ │ ├── binary/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── add.cu │ │ │ │ ├── arctan2.cu │ │ │ │ ├── binary.cuh │ │ │ │ ├── bitwise_binary.cu │ │ │ │ ├── divide.cu │ │ │ │ ├── equal.cu │ │ │ │ ├── greater.cu │ │ │ │ ├── greater_equal.cu │ │ │ │ ├── less.cu │ │ │ │ ├── less_equal.cu │ │ │ │ ├── log_add_exp.cu │ │ │ │ ├── logical_and.cu │ │ │ │ ├── logical_or.cu │ │ │ │ ├── maximum.cu │ │ │ │ ├── minimum.cu │ │ │ │ ├── multiply.cu │ │ │ │ ├── not_equal.cu │ │ │ │ ├── power.cu │ │ │ │ ├── remainder.cu │ │ │ │ └── subtract.cu │ │ │ ├── binary_two.cu │ │ │ ├── compiled.cpp │ │ │ ├── conv/ │ │ │ │ ├── conv.h │ │ │ │ ├── gemm_conv.cu │ │ │ │ └── gemm_grouped_conv.cu │ │ │ ├── conv.cpp │ │ │ ├── copy/ │ │ │ │ ├── copy.cuh │ │ │ │ ├── copy_contiguous.cu │ │ │ │ ├── copy_general.cu │ │ │ │ ├── copy_general_dynamic.cu │ │ │ │ └── copy_general_input.cu │ │ │ ├── 
copy.cu │ │ │ ├── cublas_utils.cpp │ │ │ ├── cublas_utils.h │ │ │ ├── cuda.h │ │ │ ├── cuda_utils.h │ │ │ ├── cudnn_utils.cpp │ │ │ ├── cudnn_utils.h │ │ │ ├── custom_kernel.cpp │ │ │ ├── cutlass_utils.cuh │ │ │ ├── delayload.cpp │ │ │ ├── device/ │ │ │ │ ├── atomic_ops.cuh │ │ │ │ ├── binary_ops.cuh │ │ │ │ ├── cast_op.cuh │ │ │ │ ├── complex.cuh │ │ │ │ ├── config.h │ │ │ │ ├── fp16_math.cuh │ │ │ │ ├── gather.cuh │ │ │ │ ├── gather_axis.cuh │ │ │ │ ├── hadamard.cuh │ │ │ │ ├── indexing.cuh │ │ │ │ ├── scatter.cuh │ │ │ │ ├── scatter_axis.cuh │ │ │ │ ├── scatter_ops.cuh │ │ │ │ ├── slice_update.cuh │ │ │ │ ├── ternary_ops.cuh │ │ │ │ ├── unary_ops.cuh │ │ │ │ └── utils.cuh │ │ │ ├── device.cpp │ │ │ ├── device.h │ │ │ ├── device_info.cpp │ │ │ ├── distributed.cu │ │ │ ├── eval.cpp │ │ │ ├── event.cu │ │ │ ├── event.h │ │ │ ├── fence.cpp │ │ │ ├── fft.cu │ │ │ ├── gemms/ │ │ │ │ ├── cublas_gemm.cpp │ │ │ │ ├── cublas_gemm.h │ │ │ │ ├── cublas_gemm_batched_12_0.cpp │ │ │ │ ├── cublas_gemm_batched_12_9.cu │ │ │ │ ├── gemv.cu │ │ │ │ ├── gemv.h │ │ │ │ ├── grouped_gemm.h │ │ │ │ └── grouped_gemm_unaligned.cu │ │ │ ├── hadamard.cu │ │ │ ├── indexing.cpp │ │ │ ├── jit_module.cpp │ │ │ ├── jit_module.h │ │ │ ├── kernel_utils.cu │ │ │ ├── kernel_utils.cuh │ │ │ ├── layer_norm.cu │ │ │ ├── load.cpp │ │ │ ├── logsumexp.cu │ │ │ ├── lru_cache.h │ │ │ ├── matmul.cpp │ │ │ ├── no_cuda.cpp │ │ │ ├── primitives.cpp │ │ │ ├── quantized/ │ │ │ │ ├── affine_quantize.cu │ │ │ │ ├── convert_fp8.cu │ │ │ │ ├── cublas_qqmm.cpp │ │ │ │ ├── cublas_qqmm.h │ │ │ │ ├── fp_quantize.cu │ │ │ │ ├── mxfp8_quantize.cuh │ │ │ │ ├── no_qqmm_impl.cpp │ │ │ │ ├── nvfp4_quantize.cuh │ │ │ │ ├── qmm/ │ │ │ │ │ ├── CMakeLists.txt │ │ │ │ │ ├── fp_qmv.cu │ │ │ │ │ ├── qmm.cu │ │ │ │ │ ├── qmm.h │ │ │ │ │ ├── qmm_impl_sm80.cuh │ │ │ │ │ ├── qmm_impl_sm80_m16.cu │ │ │ │ │ ├── qmm_impl_sm80_m32.cu │ │ │ │ │ ├── qmm_impl_sm80_m64.cu │ │ │ │ │ ├── qmm_impl_sm90.cuh │ │ │ │ │ ├── 
qmm_impl_sm90_m128_n128_m2.cu │ │ │ │ │ ├── qmm_impl_sm90_m128_n16_m1.cu │ │ │ │ │ ├── qmm_impl_sm90_m128_n256_m2.cu │ │ │ │ │ ├── qmm_impl_sm90_m128_n32_m1.cu │ │ │ │ │ ├── qmm_impl_sm90_m128_n64_m2.cu │ │ │ │ │ └── qmv.cu │ │ │ │ ├── qqmm.cpp │ │ │ │ ├── qqmm_impl.cpp │ │ │ │ ├── qqmm_impl.h │ │ │ │ ├── qqmm_utils.cu │ │ │ │ ├── qqmm_utils.h │ │ │ │ ├── quantized.cpp │ │ │ │ ├── quantized.h │ │ │ │ └── quantized_utils.h │ │ │ ├── random.cu │ │ │ ├── reduce/ │ │ │ │ ├── all_reduce.cu │ │ │ │ ├── col_reduce.cu │ │ │ │ ├── init_reduce.cu │ │ │ │ ├── reduce.cuh │ │ │ │ ├── reduce_ops.cuh │ │ │ │ ├── reduce_utils.cuh │ │ │ │ └── row_reduce.cu │ │ │ ├── reduce.cu │ │ │ ├── rms_norm.cu │ │ │ ├── rope.cu │ │ │ ├── scaled_dot_product_attention.cpp │ │ │ ├── scaled_dot_product_attention.cu │ │ │ ├── scan.cu │ │ │ ├── slicing.cpp │ │ │ ├── softmax.cu │ │ │ ├── sort.cu │ │ │ ├── steel/ │ │ │ │ ├── defines.cuh │ │ │ │ ├── gemm.cuh │ │ │ │ ├── mma.cuh │ │ │ │ ├── tiles.cuh │ │ │ │ └── utils.cuh │ │ │ ├── ternary.cu │ │ │ ├── unary/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── abs.cu │ │ │ │ ├── arccos.cu │ │ │ │ ├── arccosh.cu │ │ │ │ ├── arcsin.cu │ │ │ │ ├── arcsinh.cu │ │ │ │ ├── arctan.cu │ │ │ │ ├── arctanh.cu │ │ │ │ ├── bitwise_invert.cu │ │ │ │ ├── ceil.cu │ │ │ │ ├── conjugate.cu │ │ │ │ ├── cos.cu │ │ │ │ ├── cosh.cu │ │ │ │ ├── erf.cu │ │ │ │ ├── erf_inv.cu │ │ │ │ ├── exp.cu │ │ │ │ ├── expm1.cu │ │ │ │ ├── floor.cu │ │ │ │ ├── imag.cu │ │ │ │ ├── log.cu │ │ │ │ ├── log1p.cu │ │ │ │ ├── logical_not.cu │ │ │ │ ├── negative.cu │ │ │ │ ├── real.cu │ │ │ │ ├── round.cu │ │ │ │ ├── sigmoid.cu │ │ │ │ ├── sign.cu │ │ │ │ ├── sin.cu │ │ │ │ ├── sinh.cu │ │ │ │ ├── sqrt.cu │ │ │ │ ├── square.cu │ │ │ │ ├── tan.cu │ │ │ │ ├── tanh.cu │ │ │ │ └── unary.cuh │ │ │ ├── utils.cpp │ │ │ ├── utils.h │ │ │ ├── vector_types.cuh │ │ │ ├── worker.cpp │ │ │ └── worker.h │ │ ├── gpu/ │ │ │ ├── CMakeLists.txt │ │ │ ├── copy.cpp │ │ │ ├── copy.h │ │ │ ├── device_info.h │ │ │ ├── eval.h │ │ │ 
├── primitives.cpp │ │ │ ├── scan.h │ │ │ ├── slicing.cpp │ │ │ └── slicing.h │ │ ├── metal/ │ │ │ ├── CMakeLists.txt │ │ │ ├── allocator.cpp │ │ │ ├── allocator.h │ │ │ ├── binary.cpp │ │ │ ├── binary.h │ │ │ ├── compiled.cpp │ │ │ ├── conv.cpp │ │ │ ├── copy.cpp │ │ │ ├── custom_kernel.cpp │ │ │ ├── device.cpp │ │ │ ├── device.h │ │ │ ├── device_info.cpp │ │ │ ├── distributed.cpp │ │ │ ├── eval.cpp │ │ │ ├── event.cpp │ │ │ ├── fence.cpp │ │ │ ├── fft.cpp │ │ │ ├── hadamard.cpp │ │ │ ├── indexing.cpp │ │ │ ├── jit/ │ │ │ │ ├── includes.h │ │ │ │ └── indexing.h │ │ │ ├── jit_kernels.cpp │ │ │ ├── kernels/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── arange.h │ │ │ │ ├── arange.metal │ │ │ │ ├── arg_reduce.metal │ │ │ │ ├── atomic.h │ │ │ │ ├── bf16.h │ │ │ │ ├── bf16_math.h │ │ │ │ ├── binary.h │ │ │ │ ├── binary.metal │ │ │ │ ├── binary_ops.h │ │ │ │ ├── binary_two.h │ │ │ │ ├── binary_two.metal │ │ │ │ ├── cexpf.h │ │ │ │ ├── complex.h │ │ │ │ ├── conv.metal │ │ │ │ ├── copy.h │ │ │ │ ├── copy.metal │ │ │ │ ├── defines.h │ │ │ │ ├── erf.h │ │ │ │ ├── expm1f.h │ │ │ │ ├── fence.metal │ │ │ │ ├── fft/ │ │ │ │ │ ├── radix.h │ │ │ │ │ └── readwrite.h │ │ │ │ ├── fft.h │ │ │ │ ├── fft.metal │ │ │ │ ├── fp4.h │ │ │ │ ├── fp8.h │ │ │ │ ├── fp_quantized.h │ │ │ │ ├── fp_quantized.metal │ │ │ │ ├── fp_quantized_nax.h │ │ │ │ ├── fp_quantized_nax.metal │ │ │ │ ├── gemv.metal │ │ │ │ ├── gemv_masked.h │ │ │ │ ├── gemv_masked.metal │ │ │ │ ├── hadamard.h │ │ │ │ ├── indexing/ │ │ │ │ │ ├── gather.h │ │ │ │ │ ├── gather_axis.h │ │ │ │ │ ├── gather_front.h │ │ │ │ │ ├── indexing.h │ │ │ │ │ ├── masked_scatter.h │ │ │ │ │ ├── scatter.h │ │ │ │ │ └── scatter_axis.h │ │ │ │ ├── layer_norm.metal │ │ │ │ ├── logging.h │ │ │ │ ├── logsumexp.h │ │ │ │ ├── logsumexp.metal │ │ │ │ ├── quantized.h │ │ │ │ ├── quantized.metal │ │ │ │ ├── quantized_nax.h │ │ │ │ ├── quantized_nax.metal │ │ │ │ ├── quantized_utils.h │ │ │ │ ├── random.metal │ │ │ │ ├── reduce.h │ │ │ │ ├── reduce.metal │ │ │ │ 
├── reduce_utils.h │ │ │ │ ├── reduction/ │ │ │ │ │ ├── ops.h │ │ │ │ │ ├── reduce_all.h │ │ │ │ │ ├── reduce_col.h │ │ │ │ │ ├── reduce_init.h │ │ │ │ │ └── reduce_row.h │ │ │ │ ├── rms_norm.metal │ │ │ │ ├── rope.metal │ │ │ │ ├── scaled_dot_product_attention.metal │ │ │ │ ├── scan.h │ │ │ │ ├── scan.metal │ │ │ │ ├── sdpa_vector.h │ │ │ │ ├── softmax.h │ │ │ │ ├── softmax.metal │ │ │ │ ├── sort.h │ │ │ │ ├── sort.metal │ │ │ │ ├── steel/ │ │ │ │ │ ├── attn/ │ │ │ │ │ │ ├── attn.h │ │ │ │ │ │ ├── kernels/ │ │ │ │ │ │ │ ├── steel_attention.h │ │ │ │ │ │ │ ├── steel_attention.metal │ │ │ │ │ │ │ ├── steel_attention_nax.h │ │ │ │ │ │ │ └── steel_attention_nax.metal │ │ │ │ │ │ ├── loader.h │ │ │ │ │ │ ├── mma.h │ │ │ │ │ │ ├── nax.h │ │ │ │ │ │ ├── params.h │ │ │ │ │ │ └── transforms.h │ │ │ │ │ ├── conv/ │ │ │ │ │ │ ├── conv.h │ │ │ │ │ │ ├── kernels/ │ │ │ │ │ │ │ ├── steel_conv.h │ │ │ │ │ │ │ ├── steel_conv.metal │ │ │ │ │ │ │ ├── steel_conv_3d.h │ │ │ │ │ │ │ ├── steel_conv_3d.metal │ │ │ │ │ │ │ ├── steel_conv_general.h │ │ │ │ │ │ │ └── steel_conv_general.metal │ │ │ │ │ │ ├── loader.h │ │ │ │ │ │ ├── loaders/ │ │ │ │ │ │ │ ├── loader_channel_l.h │ │ │ │ │ │ │ ├── loader_channel_n.h │ │ │ │ │ │ │ └── loader_general.h │ │ │ │ │ │ └── params.h │ │ │ │ │ ├── defines.h │ │ │ │ │ ├── gemm/ │ │ │ │ │ │ ├── gemm.h │ │ │ │ │ │ ├── gemm_nax.h │ │ │ │ │ │ ├── kernels/ │ │ │ │ │ │ │ ├── steel_gemm_fused.h │ │ │ │ │ │ │ ├── steel_gemm_fused.metal │ │ │ │ │ │ │ ├── steel_gemm_fused_nax.h │ │ │ │ │ │ │ ├── steel_gemm_fused_nax.metal │ │ │ │ │ │ │ ├── steel_gemm_gather.h │ │ │ │ │ │ │ ├── steel_gemm_gather.metal │ │ │ │ │ │ │ ├── steel_gemm_gather_nax.h │ │ │ │ │ │ │ ├── steel_gemm_gather_nax.metal │ │ │ │ │ │ │ ├── steel_gemm_masked.h │ │ │ │ │ │ │ ├── steel_gemm_masked.metal │ │ │ │ │ │ │ ├── steel_gemm_segmented.h │ │ │ │ │ │ │ ├── steel_gemm_segmented.metal │ │ │ │ │ │ │ ├── steel_gemm_splitk.h │ │ │ │ │ │ │ ├── steel_gemm_splitk.metal │ │ │ │ │ │ │ ├── 
steel_gemm_splitk_nax.h │ │ │ │ │ │ │ └── steel_gemm_splitk_nax.metal │ │ │ │ │ │ ├── loader.h │ │ │ │ │ │ ├── mma.h │ │ │ │ │ │ ├── nax.h │ │ │ │ │ │ ├── params.h │ │ │ │ │ │ └── transforms.h │ │ │ │ │ ├── utils/ │ │ │ │ │ │ ├── integral_constant.h │ │ │ │ │ │ └── type_traits.h │ │ │ │ │ └── utils.h │ │ │ │ ├── ternary.h │ │ │ │ ├── ternary.metal │ │ │ │ ├── ternary_ops.h │ │ │ │ ├── unary.h │ │ │ │ ├── unary.metal │ │ │ │ ├── unary_ops.h │ │ │ │ └── utils.h │ │ │ ├── kernels.h │ │ │ ├── logsumexp.cpp │ │ │ ├── make_compiled_preamble.sh │ │ │ ├── matmul.cpp │ │ │ ├── matmul.h │ │ │ ├── metal.cpp │ │ │ ├── metal.h │ │ │ ├── no_metal.cpp │ │ │ ├── nojit_kernels.cpp │ │ │ ├── normalization.cpp │ │ │ ├── primitives.cpp │ │ │ ├── quantized.cpp │ │ │ ├── reduce.cpp │ │ │ ├── reduce.h │ │ │ ├── resident.cpp │ │ │ ├── resident.h │ │ │ ├── rope.cpp │ │ │ ├── scaled_dot_product_attention.cpp │ │ │ ├── scan.cpp │ │ │ ├── slicing.cpp │ │ │ ├── softmax.cpp │ │ │ ├── sort.cpp │ │ │ ├── ternary.cpp │ │ │ ├── ternary.h │ │ │ ├── unary.cpp │ │ │ ├── unary.h │ │ │ ├── utils.cpp │ │ │ └── utils.h │ │ ├── no_cpu/ │ │ │ ├── CMakeLists.txt │ │ │ ├── compiled.cpp │ │ │ ├── device_info.cpp │ │ │ └── primitives.cpp │ │ └── no_gpu/ │ │ ├── CMakeLists.txt │ │ ├── allocator.cpp │ │ ├── apple_memory.h │ │ ├── device_info.cpp │ │ ├── eval.cpp │ │ ├── event.cpp │ │ ├── fence.cpp │ │ ├── linux_memory.h │ │ └── primitives.cpp │ ├── compile.cpp │ ├── compile.h │ ├── compile_impl.h │ ├── device.cpp │ ├── device.h │ ├── distributed/ │ │ ├── CMakeLists.txt │ │ ├── distributed.cpp │ │ ├── distributed.h │ │ ├── distributed_impl.h │ │ ├── jaccl/ │ │ │ ├── CMakeLists.txt │ │ │ ├── jaccl.cpp │ │ │ ├── jaccl.h │ │ │ ├── mesh.cpp │ │ │ ├── mesh.h │ │ │ ├── mesh_impl.h │ │ │ ├── no_jaccl.cpp │ │ │ ├── ring.cpp │ │ │ ├── ring.h │ │ │ ├── ring_impl.h │ │ │ ├── utils.cpp │ │ │ └── utils.h │ │ ├── mpi/ │ │ │ ├── CMakeLists.txt │ │ │ ├── mpi.cpp │ │ │ ├── mpi.h │ │ │ ├── mpi_declarations.h │ │ │ └── no_mpi.cpp │ 
│ ├── nccl/ │ │ │ ├── CMakeLists.txt │ │ │ ├── nccl.cpp │ │ │ ├── nccl.h │ │ │ └── no_nccl.cpp │ │ ├── ops.cpp │ │ ├── ops.h │ │ ├── primitives.cpp │ │ ├── primitives.h │ │ ├── reduction_ops.h │ │ ├── ring/ │ │ │ ├── CMakeLists.txt │ │ │ ├── no_ring.cpp │ │ │ ├── ring.cpp │ │ │ └── ring.h │ │ ├── utils.cpp │ │ └── utils.h │ ├── dtype.cpp │ ├── dtype.h │ ├── dtype_utils.cpp │ ├── dtype_utils.h │ ├── einsum.cpp │ ├── einsum.h │ ├── event.h │ ├── export.cpp │ ├── export.h │ ├── export_impl.h │ ├── fast.cpp │ ├── fast.h │ ├── fast_primitives.h │ ├── fence.h │ ├── fft.cpp │ ├── fft.h │ ├── graph_utils.cpp │ ├── graph_utils.h │ ├── io/ │ │ ├── CMakeLists.txt │ │ ├── gguf.cpp │ │ ├── gguf.h │ │ ├── gguf_quants.cpp │ │ ├── load.cpp │ │ ├── load.h │ │ ├── no_gguf.cpp │ │ ├── no_safetensors.cpp │ │ └── safetensors.cpp │ ├── io.h │ ├── linalg.cpp │ ├── linalg.h │ ├── memory.h │ ├── mlx.h │ ├── ops.cpp │ ├── ops.h │ ├── primitives.cpp │ ├── primitives.h │ ├── random.cpp │ ├── random.h │ ├── scheduler.cpp │ ├── scheduler.h │ ├── small_vector.h │ ├── stream.h │ ├── threadpool.h │ ├── transforms.cpp │ ├── transforms.h │ ├── transforms_impl.h │ ├── types/ │ │ ├── bf16.h │ │ ├── complex.h │ │ ├── fp16.h │ │ ├── half_types.h │ │ └── limits.h │ ├── utils.cpp │ ├── utils.h │ ├── version.cpp │ └── version.h ├── mlx.pc.in ├── pyproject.toml ├── python/ │ ├── mlx/ │ │ ├── __main__.py │ │ ├── _distributed_utils/ │ │ │ ├── common.py │ │ │ ├── config.py │ │ │ └── launch.py │ │ ├── _reprlib_fix.py │ │ ├── _stub_patterns.txt │ │ ├── extension.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ ├── init.py │ │ │ ├── layers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activations.py │ │ │ │ ├── base.py │ │ │ │ ├── containers.py │ │ │ │ ├── convolution.py │ │ │ │ ├── convolution_transpose.py │ │ │ │ ├── distributed.py │ │ │ │ ├── dropout.py │ │ │ │ ├── embedding.py │ │ │ │ ├── linear.py │ │ │ │ ├── normalization.py │ │ │ │ ├── pooling.py │ │ │ │ ├── positional_encoding.py │ │ │ │ ├── quantized.py │ │ │ │ ├── 
recurrent.py │ │ │ │ ├── transformer.py │ │ │ │ └── upsample.py │ │ │ ├── losses.py │ │ │ └── utils.py │ │ ├── optimizers/ │ │ │ ├── __init__.py │ │ │ ├── optimizers.py │ │ │ └── schedulers.py │ │ ├── py.typed │ │ └── utils.py │ ├── src/ │ │ ├── CMakeLists.txt │ │ ├── array.cpp │ │ ├── buffer.h │ │ ├── constants.cpp │ │ ├── convert.cpp │ │ ├── convert.h │ │ ├── cuda.cpp │ │ ├── device.cpp │ │ ├── distributed.cpp │ │ ├── export.cpp │ │ ├── fast.cpp │ │ ├── fft.cpp │ │ ├── indexing.cpp │ │ ├── indexing.h │ │ ├── linalg.cpp │ │ ├── load.cpp │ │ ├── load.h │ │ ├── memory.cpp │ │ ├── metal.cpp │ │ ├── mlx.cpp │ │ ├── mlx_func.cpp │ │ ├── mlx_func.h │ │ ├── ops.cpp │ │ ├── random.cpp │ │ ├── small_vector.h │ │ ├── stream.cpp │ │ ├── transforms.cpp │ │ ├── trees.cpp │ │ ├── trees.h │ │ ├── utils.cpp │ │ └── utils.h │ └── tests/ │ ├── __main__.py │ ├── cuda_skip.py │ ├── mlx_distributed_tests.py │ ├── mlx_tests.py │ ├── mpi_test_distributed.py │ ├── nccl_test_distributed.py │ ├── ring_test_distributed.py │ ├── test_array.py │ ├── test_autograd.py │ ├── test_bf16.py │ ├── test_blas.py │ ├── test_compile.py │ ├── test_constants.py │ ├── test_conv.py │ ├── test_conv_transpose.py │ ├── test_device.py │ ├── test_double.py │ ├── test_einsum.py │ ├── test_eval.py │ ├── test_export_import.py │ ├── test_fast.py │ ├── test_fast_sdpa.py │ ├── test_fft.py │ ├── test_graph.py │ ├── test_init.py │ ├── test_linalg.py │ ├── test_load.py │ ├── test_losses.py │ ├── test_memory.py │ ├── test_nn.py │ ├── test_ops.py │ ├── test_optimizers.py │ ├── test_quantized.py │ ├── test_random.py │ ├── test_reduce.py │ ├── test_tree.py │ ├── test_upsample.py │ └── test_vmap.py ├── setup.py └── tests/ ├── CMakeLists.txt ├── allocator_tests.cpp ├── arg_reduce_tests.cpp ├── array_tests.cpp ├── autograd_tests.cpp ├── blas_tests.cpp ├── compile_tests.cpp ├── creations_tests.cpp ├── custom_vjp_tests.cpp ├── device_tests.cpp ├── einsum_tests.cpp ├── eval_tests.cpp ├── export_import_tests.cpp ├── fft_tests.cpp 
├── gpu_tests.cpp ├── linalg_tests.cpp ├── load_tests.cpp ├── ops_tests.cpp ├── random_tests.cpp ├── scheduler_tests.cpp ├── test_teardown.cpp ├── tests.cpp ├── utils_tests.cpp └── vmap_tests.cpp ================================================ FILE CONTENTS ================================================ ================================================ FILE: .clang-format ================================================ --- AccessModifierOffset: -1 AlignAfterOpenBracket: AlwaysBreak AlignConsecutiveAssignments: false AlignConsecutiveDeclarations: false AlignEscapedNewlinesLeft: true AlignOperands: false AlignTrailingComments: false AllowAllParametersOfDeclarationOnNextLine: false AllowShortBlocksOnASingleLine: false AllowShortCaseLabelsOnASingleLine: false AllowShortFunctionsOnASingleLine: Empty AllowShortIfStatementsOnASingleLine: false AllowShortLoopsOnASingleLine: false AlwaysBreakAfterReturnType: None AlwaysBreakBeforeMultilineStrings: true AlwaysBreakTemplateDeclarations: true BinPackArguments: false BinPackParameters: false BraceWrapping: AfterClass: false AfterControlStatement: false AfterEnum: false AfterFunction: false AfterNamespace: false AfterObjCDeclaration: false AfterStruct: false AfterUnion: false BeforeCatch: false BeforeElse: false IndentBraces: false BreakBeforeBinaryOperators: None BreakBeforeBraces: Attach BreakBeforeTernaryOperators: true BreakConstructorInitializersBeforeComma: false BreakAfterJavaFieldAnnotations: false BreakStringLiterals: false ColumnLimit: 80 CommentPragmas: '^ IWYU pragma:' ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 4 Cpp11BracedListStyle: true DerivePointerAlignment: false DisableFormat: false ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] IncludeCategories: - Regex: '^<.*\.h(pp)?>' Priority: 1 - Regex: '^<.*' Priority: 2 - Regex: '.*' Priority: 3 IndentCaseLabels: true IndentWidth: 2 IndentWrappedFunctionNames: false 
KeepEmptyLinesAtTheStartOfBlocks: false MacroBlockBegin: '' MacroBlockEnd: '' MaxEmptyLinesToKeep: 1 NamespaceIndentation: None ObjCBlockIndentWidth: 2 ObjCSpaceAfterProperty: false ObjCSpaceBeforeProtocolList: false PenaltyBreakBeforeFirstCallParameter: 1 PenaltyBreakComment: 300 PenaltyBreakFirstLessLess: 120 PenaltyBreakString: 1000 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 PointerAlignment: Left ReflowComments: true SortIncludes: true SpaceAfterCStyleCast: false SpaceBeforeAssignmentOperators: true SpaceBeforeParens: ControlStatements SpaceInEmptyParentheses: false SpacesBeforeTrailingComments: 1 SpacesInAngles: false SpacesInContainerLiterals: true SpacesInCStyleCastParentheses: false SpacesInParentheses: false SpacesInSquareBrackets: false Standard: Cpp11 TabWidth: 8 UseTab: Never ... ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report about an issue you've encountered title: "[BUG] " labels: '' assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Include a code snippet ```python ``` **Expected behavior** A clear and concise description of what you expected to happen. **Desktop (please complete the following information):** - OS Version: [e.g. macOS 14.1.2] - MLX Version: [e.g. 0.7.0] **Additional context** Add any other context about the problem here.
================================================ FILE: .github/actions/build-cuda-release/action.yml ================================================ name: 'Build CUDA wheel' description: 'Build CUDA wheel' inputs: arch: description: 'Platform architecture tag' required: true type: choice options: - x86_64 - aarch64 runs: using: "composite" steps: - name: Build package shell: bash env: CMAKE_ARGS: -DMLX_BUILD_CUDA=ON run: | pip install auditwheel build patchelf setuptools python setup.py clean --all MLX_DISABLE_SM90A_KERNELS=1 MLX_BUILD_STAGE=2 python -m build -w auditwheel repair dist/mlx_cuda*.whl \ --plat manylinux_2_35_${{ inputs.arch }} \ --exclude libcublas* \ --exclude libcuda* \ --exclude libcudnn* \ --exclude libnccl* \ --exclude libnvrtc* ================================================ FILE: .github/actions/build-docs/action.yml ================================================ name: 'Build Documentation' description: 'Build documentation' runs: using: "composite" steps: - name: Setup machine uses: ./.github/actions/setup-linux - name: Install dependencies shell: bash run: | sudo apt-get install -y doxygen source .venv/bin/activate pip install -r docs/requirements.txt pip install . 
-v - name: Build documentation shell: bash run: | source .venv/bin/activate cd docs doxygen make html O=-W - name: Create artifact tar shell: bash run: tar -cf artifact.tar -C docs --dereference build/html index.html # Do it manually because upload-pages-artifact requires gtar - name: Upload artifact id: upload-artifact uses: actions/upload-artifact@v5 with: name: github-pages path: artifact.tar retention-days: 1 if-no-files-found: error ================================================ FILE: .github/actions/build-linux/action.yml ================================================ name: 'Build and Test on Linux' inputs: toolkit: description: 'The toolkit to build with' required: false default: 'cpu' runs: using: "composite" steps: - name: Install Python package id: python_build shell: sh env: DEBUG: 1 CMAKE_ARGS: >- -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }} run: | if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then # There is no GPU on the arm64 runner, so target a common arch. CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=80" # Cannot build tests and stubs when the built executables cannot run. CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF -DMLX_BUILD_PYTHON_STUBS=OFF" fi # Install CPU-only torch to save space pip install torch --index-url https://download.pytorch.org/whl/cpu pip install --no-build-isolation -e ".[dev]" -v # Pass CMAKE_ARGS to the following steps. echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT - name: Build CPP only shell: bash run: | cmake .
-B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }} cmake --build build -j $(nproc) ================================================ FILE: .github/actions/build-linux-release/action.yml ================================================ name: 'Build Linux wheel' description: 'Build Linux wheel' inputs: build-backend: description: 'Build the backend mlx-cpu package' type: boolean required: false default: false arch: description: 'Platform architecture tag' required: true type: choice options: - x86_64 - aarch64 runs: using: "composite" steps: - name: Build MLX shell: bash run: pip install -e . -v - name: Build Python package shell: bash run: | pip install auditwheel patchelf build python setup.py clean --all MLX_BUILD_STAGE=1 python -m build -w auditwheel repair dist/mlx-*.whl \ --plat manylinux_2_35_${{ inputs.arch }} \ --exclude libmlx.so* \ --only-plat - name: Build backend package if: ${{ inputs.build-backend }} shell: bash run: | python setup.py clean --all MLX_BUILD_STAGE=2 python -m build -w auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }} ================================================ FILE: .github/actions/build-macos/action.yml ================================================ name: 'Build and Test on macOS' description: 'Build and test MLX on macOS' runs: using: "composite" steps: - name: Install dependencies env: DEBUG: 1 CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" shell: bash -l {0} run: | pip install --upgrade pip pip install cmake setuptools typing_extensions pip install -e ".[dev]" -v - name: Install tests dependencies shell: bash -l {0} run: | pip install tensorflow - name: Run Python tests shell: bash -l {0} env: LOW_MEMORY: 1 run: | DEVICE=cpu python -m unittest discover -v python/tests DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m unittest discover -v python/tests mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python 
python/tests/mpi_test_distributed.py mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2) if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi - name: Build example extension shell: bash -l {0} run: | cd examples/extensions pip install -r requirements.txt python setup.py build_ext --inplace python test.py - name: Build CPP only shell: bash -l {0} run: | mkdir -p build cd build cmake .. make -j $(sysctl -n hw.ncpu) - name: Run CPP tests shell: bash -l {0} env: DEVICE: gpu METAL_DEVICE_WRAPPER_TYPE: 1 METAL_DEBUG_ERROR_MODE: 0 run: ./build/tests/tests - name: Build small binary with JIT shell: bash -l {0} run: | mkdir -p build cd build cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \ -DBUILD_SHARED_LIBS=ON \ -DMLX_BUILD_CPU=OFF \ -DMLX_BUILD_SAFETENSORS=OFF \ -DMLX_BUILD_GGUF=OFF \ -DMLX_METAL_JIT=ON make -j $(sysctl -n hw.ncpu) - name: Run Python tests with JIT shell: bash -l {0} env: LOW_MEMORY: 1 DEVICE: gpu METAL_DEVICE_WRAPPER_TYPE: 1 METAL_DEBUG_ERROR_MODE: 0 run: | CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ pip install -e . 
-v python -m unittest discover -v python/tests ================================================ FILE: .github/actions/build-macos-release/action.yml ================================================ name: 'Build macOS release' description: 'Build MLX releases macOS' inputs: macos-target: description: 'macOS build target' required: false default: '15.0' build-backend: description: 'Build the backend mlx-metal package' type: boolean required: false default: false runs: using: "composite" steps: - name: Build Python package shell: bash -l {0} env: DEVELOPER_DIR: /Applications/Xcode-latest.app MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }} run: | pip install build python setup.py clean --all MLX_BUILD_STAGE=1 python -m build -w - name: Build backend package if: ${{ inputs.build-backend }} shell: bash -l {0} env: DEVELOPER_DIR: /Applications/Xcode-latest.app MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }} run: | python setup.py clean --all MLX_BUILD_STAGE=2 python -m build -w ================================================ FILE: .github/actions/build-windows/action.yml ================================================ name: 'Build on Windows' runs: using: 'composite' steps: - name: Install Python package id: python-build shell: cmd env: # For MSVC, Ninja/Release is the only config supported by ccache. CMAKE_ARGS: >- -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=cl -DCMAKE_RC_COMPILER=rc run: | uv pip install ".[dev]" -v :: Pass the CMAKE_ARGS to following steps. >>%GITHUB_OUTPUT% ECHO CMAKE_ARGS=%CMAKE_ARGS% - name: Build CPP only shell: cmd run: | cmake . 
-B build ${{ steps.python-build.outputs.CMAKE_ARGS }} cmake --build build -j %NUMBER_OF_PROCESSORS% ================================================ FILE: .github/actions/setup-linux/action.yml ================================================ name: 'Setup Linux Environment' description: 'Install dependencies for Linux builds' inputs: toolkit: description: 'Which toolkit to install' required: false default: 'cpu' python-version: description: 'Version of python to set up' required: false default: '3.14' use-ccache: description: 'Whether to enable ccache' required: false default: 'true' runs: using: "composite" steps: - name: Install common dependencies shell: bash run: | echo "::group::Install common dependencies" sudo apt-get update sudo apt-get install -y --no-install-recommends \ zip \ libblas-dev liblapack-dev liblapacke-dev \ openmpi-bin openmpi-common libopenmpi-dev echo "::endgroup::" - name: Use ccache if: ${{ inputs.use-ccache == 'true' }} uses: hendrikmuhs/ccache-action@v1.2 with: key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }} max-size: 1GB # ccache-action bug: running "apt-get update" fails on large arm runner. update-package-index: false - uses: actions/setup-python@v6 with: python-version: ${{ inputs.python-version }} - name: Setup Python venv shell: bash run: | echo "::group::Setup Python venv" python -m venv .venv source .venv/bin/activate pip install setuptools cmake typing_extensions echo PATH=$PATH >> $GITHUB_ENV # Search python packages in .venv echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV echo "::endgroup::" - name: Install CUDA toolkit if: ${{ startsWith(inputs.toolkit, 'cuda') }} shell: bash env: # Note: the CI machine does not meet CUDA 13's driver requirement. 
# Compatibility matrix: # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html PACKAGES: | { "cuda-12.6": "libcudnn9-dev-cuda-12 cuda-compiler-12-6 cuda-libraries-dev-12-6", "cuda-12.9": "libcudnn9-dev-cuda-12 cuda-compiler-12-9 cuda-libraries-dev-12-9", "cuda-13.0": "libcudnn9-dev-cuda-13 cuda-compiler-13-0 cuda-libraries-dev-13-0" } run: | echo "::group::Install CUDA toolkit" # The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is # Jetson specific. SBSA means Arm Server Base System Architecture. ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }} wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb sudo apt-get update sudo apt-get install -y --no-install-recommends \ libnccl2 libnccl-dev \ ${{ fromJson(env.PACKAGES)[inputs.toolkit] }} echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH echo "::endgroup::" - name: CUDA packages and driver report if: ${{ startsWith(inputs.toolkit, 'cuda') }} shell: bash run: | echo "::group::Installed NVIDIA and CUDA packages" dpkg -l | egrep "cuda|nvidia" -i echo "::endgroup::" echo "::group::NVIDIA-SMI Status" nvidia-smi || true echo "::endgroup::" ================================================ FILE: .github/actions/setup-macos/action.yml ================================================ name: 'Setup macOS Environment' description: 'Install dependencies for macOS builds' inputs: python-version: description: 'Python version to use' required: false default: '3.10' runs: using: "composite" steps: - name: Install Homebrew packages shell: sh run: /opt/homebrew/bin/brew install openmpi - name: Verify MetalToolchain installed shell: bash run: xcodebuild -showComponent MetalToolchain - uses: conda-incubator/setup-miniconda@v3 with: miniconda-version: "latest" python-version: ${{ inputs.python-version }} ================================================ FILE: 
.github/actions/setup-windows/action.yml ================================================ name: 'Setup Windows environment' inputs: python-version: description: 'Version of python to set up' required: false default: '3.14' use-ccache: description: 'Whether to enable ccache' required: false default: 'true' runs: using: 'composite' steps: - name: Use ccache if: ${{ inputs.use-ccache == 'true' }} uses: hendrikmuhs/ccache-action@v1.2 with: key: ccache-${{ runner.os }}-${{ runner.arch }}-cpu max-size: 1GB - name: Setup Visual Studio cmd shell: cmd run: | :: Find out path to VS. pushd "C:\Program Files (x86)\Microsoft Visual Studio\Installer\" for /f "delims=" %%x in ('.\vswhere.exe -latest -property InstallationPath') do set VSPATH=%%x popd :: Import VS vars. call "%VSPATH%\VC\Auxiliary\Build\vcvarsall.bat" x64 :: Export to all steps. >>%GITHUB_ENV% set - uses: astral-sh/setup-uv@v7 - name: Setup Python venv shell: cmd run: | uv venv --python ${{ inputs.python-version }} call ".venv/Scripts/activate.bat" >>%GITHUB_ENV% set ================================================ FILE: .github/actions/test-linux/action.yml ================================================ name: 'Run Linux tests' inputs: has-gpu: description: 'Run GPU tests' required: false default: false runs: using: "composite" steps: - name: Run MPI tests shell: bash run: | echo "::group::MPI tests" mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py echo "::endgroup::" - name: Run distributed tests if: ${{ inputs.has-gpu == 'false' }} shell: bash run: | echo "::group::Distributed tests" mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2) if grep -Fq '[WARN]' stderr.log ; then grep -F '[WARN]' stderr.log echo "Distributed ring test failed"; exit 1; fi echo "::endgroup::" - name: Run Python tests - CPU if: ${{ inputs.has-gpu == 'false' }} shell: bash env: DEVICE: cpu run: | echo "::group::Python tests - CPU" 
python -m unittest discover python/tests -v echo "::endgroup::" - name: Run Python tests - GPU if: ${{ inputs.has-gpu == 'true' }} shell: bash env: DEVICE: gpu run: | echo "::group::Python tests - GPU" python -m tests discover python/tests -v echo "::endgroup::" - name: Run CPP tests - CPU shell: bash env: DEVICE: cpu run: | echo "::group::CPP tests - CPU" ./build/tests/tests echo "::endgroup::" - name: Run CPP tests - GPU if: ${{ inputs.has-gpu == 'true' }} shell: bash env: DEVICE: gpu run: | echo "::group::CPP tests - GPU" ./build/tests/tests -sfe="*linalg_tests.cpp" echo "::endgroup::" ================================================ FILE: .github/actions/test-windows/action.yml ================================================ name: 'Run tests on Windows' runs: using: 'composite' steps: - name: Run Python tests - CPU shell: bash run: | echo "::group::Python tests - CPU" python -m unittest discover python/tests -v echo "::endgroup::" - name: Run CPP tests - CPU shell: bash env: DEVICE: cpu run: | echo "::group::CPP tests - CPU" ./build/tests.exe -tce="*gguf*,test random uniform" echo "::endgroup::" ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" ================================================ FILE: .github/pull_request_template.md ================================================ ## Proposed changes Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #. ## Checklist Put an `x` in the boxes that apply. 
- [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document - [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have updated the necessary documentation (if needed) ================================================ FILE: .github/scripts/build-sanitizer-tests.sh ================================================ #!/bin/bash set -ex export CMAKE_C_COMPILER=/usr/bin/clang export CMAKE_CXX_COMPILER=/usr/bin/clang++ BASE_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=DEBUG -DCMAKE_COMPILE_WARNING_AS_ERROR=ON" if [[ "$(uname -s)" != "Darwin" ]]; then BASE_CMAKE_ARGS+=" -DMLX_BUILD_METAL=OFF" fi run_test() { local sanitizer_name=$1 local cmake_sanitizer_flag="-DUSE_${sanitizer_name}=ON" echo " Running tests with: ${sanitizer_name}" case "$sanitizer_name" in ASAN) export ASAN_OPTIONS="detect_leaks=0" ;; UBSAN) export UBSAN_OPTIONS="halt_on_error=0:print_stacktrace=1" ;; TSAN) export TSAN_OPTIONS="" ;; esac rm -rf build mkdir -p build pushd build > /dev/null cmake .. ${BASE_CMAKE_ARGS} ${cmake_sanitizer_flag} make -j $(nproc) ./tests/tests popd > /dev/null unset ${sanitizer_name}_OPTIONS } sanitizer_arg=$(echo "$1" | tr '[:lower:]' '[:upper:]') if [[ "$sanitizer_arg" == "ASAN" || "$sanitizer_arg" == "UBSAN" || "$sanitizer_arg" == "TSAN" ]]; then run_test "$sanitizer_arg" echo " ${sanitizer_arg} test run completed successfully." else echo "Error: Invalid sanitizer '$1'. Please use one of: ASAN, UBSAN, TSAN." exit 1 fi ================================================ FILE: .github/scripts/setup+build-cpp-linux-fedora-container.sh ================================================ #!/bin/bash set -ex # [Setup] Install dependencies inside the container. 
dnf update -y dnf install -y \ blas-devel \ lapack-devel \ openblas-devel \ make \ cmake \ clang \ git dnf clean all # [C++] CI Build Sanity Check: Verifies code compilation, not for release. export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" export DEBUG=1 export CC=/usr/bin/clang export CXX=/usr/bin/clang++ mkdir -p build pushd build cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG make -j $(nproc) ./tests/tests popd ================================================ FILE: .github/workflows/build_and_test.yml ================================================ name: Build and Test on: pull_request: push: branches: - main # For testing CI without starting a pull request: - test/* permissions: contents: read concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} jobs: check_lint: name: Check Lint runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v6 - uses: pre-commit/action@v3.0.1 linux_build_and_test: name: Linux (cpu, ${{ matrix.arch }}) needs: check_lint strategy: fail-fast: false matrix: arch: ['x86_64', 'aarch64'] runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }} steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-linux - uses: ./.github/actions/build-linux - uses: ./.github/actions/test-linux - run: df -h cuda_build_and_test: name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }}) if: github.repository == 'ml-explore/mlx' needs: check_lint strategy: fail-fast: false matrix: arch: ['x86_64', 'aarch64'] toolkit: ['cuda-12.6', 'cuda-12.9'] runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }} steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-linux with: toolkit: ${{ matrix.toolkit }} - uses: ./.github/actions/build-linux with: toolkit: ${{ matrix.toolkit }} - uses: ./.github/actions/test-linux if: matrix.arch == 'x86_64' with: has-gpu: true mac_build_and_test: name: macOS 
(${{ matrix.macos-target }}) if: github.repository == 'ml-explore/mlx' strategy: matrix: macos-target: ["14.0", "15.0", "26.0"] runs-on: [self-hosted, macos] env: MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }} needs: check_lint steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-macos - uses: ./.github/actions/build-macos windows_build_and_test: name: Windows (cpu, x86_64) needs: check_lint runs-on: windows-2025 steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-windows - uses: ./.github/actions/build-windows - uses: ./.github/actions/test-windows build_documentation: name: Build Documentation if: github.repository == 'ml-explore/mlx' runs-on: ubuntu-22.04 needs: check_lint steps: - uses: actions/checkout@v6 - uses: ./.github/actions/build-docs linux_sanitizer_build_and_test: name: Linux Sanitizer Tests (${{ matrix.sanitizer }}) needs: check_lint strategy: fail-fast: false matrix: sanitizer: [ASAN, UBSAN] # todo 12/16/2025: enable TSAN later + consider enabling ASAN for GPU backend tests. 
# sanitizer: [ASAN, UBSAN, TSAN] runs-on: ubuntu-22.04-arm steps: - name: Checkout code uses: actions/checkout@v6 - name: Install Dependencies run: | export DEBIAN_FRONTEND=noninteractive sudo apt-get update -y sudo apt-get install -y \ build-essential \ libblas-dev \ liblapacke-dev \ libopenblas-dev \ cmake \ clang \ git sudo apt-get clean sudo rm -rf /var/lib/apt/lists/* - name: Linux Build and Test with ${{ matrix.sanitizer }} run: | bash .github/scripts/build-sanitizer-tests.sh ${{ matrix.sanitizer }} linux_fedora_build_cpp: name: Linux Fedora (${{ matrix.arch }}) needs: check_lint strategy: fail-fast: false matrix: include: - host: ubuntu-22.04 arch: x86_64 - host: ubuntu-22.04-arm arch: aarch64 runs-on: ${{ matrix.host }} container: image: fedora:42 steps: - name: Checkout code uses: actions/checkout@v6 - name: CPP Build Test - No Release run: | bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh ================================================ FILE: .github/workflows/documentation.yml ================================================ name: Documentation on: workflow_dispatch: permissions: contents: read jobs: build: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v6 - uses: ./.github/actions/build-docs deploy: needs: build permissions: pages: write id-token: write runs-on: ubuntu-latest environment: name: github-pages url: ${{ steps.deployment.outputs.page_url }} steps: - name: Deploy to GitHub Pages id: deployment uses: actions/deploy-pages@v4 ================================================ FILE: .github/workflows/nightly.yml ================================================ name: Nightly Build on: schedule: - cron: 33 6 * * 1-5 workflow_dispatch: permissions: contents: read jobs: build_linux_release: strategy: fail-fast: false matrix: python_version: ["3.10", "3.14"] runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-linux with: python-version: ${{ matrix.python_version 
}} - uses: ./.github/actions/build-linux-release with: build-backend: ${{
matrix.python_version == '3.10' }} arch: "x86_64" - name: Upload mlx artifacts uses: actions/upload-artifact@v7 with: name: linux-wheels-${{ matrix.python_version }} path: wheelhouse/mlx-*.whl retention-days: 7 - name: Upload mlx-cpu artifacts if: matrix.python_version == '3.10' uses: actions/upload-artifact@v7 with: name: mlx-cpu path: wheelhouse/mlx_cpu-*.whl retention-days: 7 - run: df -h build_linux_with_tests: strategy: fail-fast: false matrix: python_version: ["3.11", "3.12", "3.13", "3.14"] runner: - ubuntu-22.04 - ubuntu-22.04-arm runs-on: ${{ matrix.runner }} steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-linux with: python-version: ${{ matrix.python_version }} - uses: ./.github/actions/build-linux - uses: ./.github/actions/test-linux - run: df -h build_mac_release: if: github.repository == 'ml-explore/mlx' strategy: matrix: python-version: ["3.10", "3.13"] runs-on: [self-hosted, macos] steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-macos with: python-version: ${{ matrix.python-version }} - uses: ./.github/actions/build-macos - name: Build macOS 26 package uses: ./.github/actions/build-macos-release with: macos-target: 26.0 build-backend: ${{ matrix.python-version == '3.10' }} - name: Build macOS 15 package uses: ./.github/actions/build-macos-release with: macos-target: 15.0 build-backend: ${{ matrix.python-version == '3.10' }} - name: Build macOS 14 package uses: ./.github/actions/build-macos-release with: macos-target: 14.0 build-backend: ${{ matrix.python-version == '3.10' }} build_cuda_release: if: github.repository == 'ml-explore/mlx' runs-on: ubuntu-22-large steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-linux with: toolkit: 'cuda-12.9' - name: Build Python package uses: ./.github/actions/build-cuda-release with: toolkit: 'cuda-12.9' arch: 'x86_64' - name: Upload artifacts uses: actions/upload-artifact@v7 with: name: mlx-cuda path: wheelhouse/mlx_cuda_*.whl retention-days: 7 
================================================ FILE: .github/workflows/release.yml ================================================ name: PyPI Release on: push: tags: - 'v*' branches: - 'test-publish/*' workflow_dispatch: inputs: dry_run: description: 'Dry run (do not publish to PyPi)' required: false type: boolean dev_release: description: 'Development release (DEV_RELEASE=1)' required: false type: boolean permissions: contents: read jobs: build_documentation: if: github.repository == 'ml-explore/mlx' runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v6 - uses: ./.github/actions/build-docs deploy_documentation: if: ${{ !inputs.dry_run }} needs: build_documentation permissions: pages: write id-token: write runs-on: ubuntu-latest environment: name: github-pages url: ${{ steps.deployment.outputs.page_url }} steps: - name: Deploy to GitHub Pages id: deployment uses: actions/deploy-pages@v4 build_linux_release: if: github.repository == 'ml-explore/mlx' strategy: matrix: python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"] arch: ['x86_64', 'aarch64'] runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }} env: PYPI_RELEASE: 1 DEV_RELEASE: ${{ inputs.dev_release && 1 || 0 }} steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-linux with: python-version: ${{ matrix.python_version }} use-ccache: false - uses: ./.github/actions/build-linux-release with: build-backend: ${{ matrix.python_version == '3.10' }} arch: ${{ matrix.arch }} - name: Upload MLX artifacts uses: actions/upload-artifact@v7 with: overwrite: true name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }} path: wheelhouse/mlx-*.whl if-no-files-found: error - name: Upload CPU artifacts if: matrix.python_version == '3.10' uses: actions/upload-artifact@v7 with: overwrite: true name: mlx-cpu-${{ matrix.arch }} path: wheelhouse/mlx_cpu-*.whl if-no-files-found: error build_mac_release: if: github.repository == 'ml-explore/mlx' strategy: matrix: 
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] runs-on: [self-hosted, macos] env: PYPI_RELEASE: 1 DEV_RELEASE: ${{ inputs.dev_release && 1 || 0 }} steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-macos with: python-version: ${{ matrix.python-version }} - name: Install dependencies shell: bash -l {0} run: | pip install --upgrade pip pip install cmake setuptools typing_extensions pip install -e . -v - name: Build macOS 14 package uses: ./.github/actions/build-macos-release with: macos-target: 14.0 build-backend: ${{ matrix.python-version == '3.10' }} - name: Build macOS 15 package uses: ./.github/actions/build-macos-release with: macos-target: 15.0 build-backend: ${{ matrix.python-version == '3.10' }} - name: Build macOS 26 package uses: ./.github/actions/build-macos-release with: macos-target: 26.0 build-backend: ${{ matrix.python-version == '3.10' }} - name: Upload MLX artifacts uses: actions/upload-artifact@v7 with: overwrite: true name: mac-wheels-${{ matrix.python-version }} path: dist/mlx-*.whl if-no-files-found: error - name: Upload Metal artifacts if: matrix.python-version == '3.10' uses: actions/upload-artifact@v7 with: overwrite: true name: mlx-metal path: dist/mlx_metal-*.whl if-no-files-found: error build_cuda_release: if: github.repository == 'ml-explore/mlx' strategy: matrix: arch: ['x86_64', 'aarch64'] toolkit: ['cuda-12.9', 'cuda-13.0'] runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }} env: PYPI_RELEASE: 1 DEV_RELEASE: ${{ inputs.dev_release && 1 || 0 }} steps: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-linux with: toolkit: ${{ matrix.toolkit }} use-ccache: false - name: Build Python package uses: ./.github/actions/build-cuda-release with: arch: ${{ matrix.arch }} - name: Upload artifacts uses: actions/upload-artifact@v7 with: overwrite: true name: mlx-${{ matrix.toolkit }}-${{ matrix.arch }} path: wheelhouse/mlx_cuda_*.whl if-no-files-found: error pypi-publish: name: 
Upload release to PyPI runs-on: ubuntu-latest needs: [build_linux_release, build_mac_release] permissions: id-token: write environment: name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }} url: https://pypi.org/p/mlx steps: - uses: actions/download-artifact@v8 with: pattern: linux-wheels-* merge-multiple: true path: dist - uses: actions/download-artifact@v8 with: pattern: mac-wheels-* merge-multiple: true path: dist - name: Display structure of downloaded files run: du -ah dist - name: Publish package distributions to PyPI if: ${{ !inputs.dry_run }} uses: pypa/gh-action-pypi-publish@release/v1 with: repository-url: https://upload.pypi.org/legacy/ pypi-publish-cuda: name: Upload CUDA release to PyPI runs-on: ubuntu-latest needs: [build_cuda_release] permissions: id-token: write environment: name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }} url: https://pypi.org/p/mlx-cuda steps: - uses: actions/download-artifact@v8 with: pattern: mlx-cuda-* merge-multiple: true path: dist - name: Display structure of downloaded files run: du -ah dist - name: Publish package distributions to PyPI if: ${{ !inputs.dry_run }} uses: pypa/gh-action-pypi-publish@release/v1 with: repository-url: https://upload.pypi.org/legacy/ pypi-publish-cpu: name: Upload CPU release to PyPI runs-on: ubuntu-latest needs: [build_linux_release] permissions: id-token: write environment: name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }} url: https://pypi.org/p/mlx-cpu steps: - uses: actions/download-artifact@v8 with: pattern: mlx-cpu-* merge-multiple: true path: dist - name: Display structure of downloaded files run: du -ah dist - name: Publish package distributions to PyPI if: ${{ !inputs.dry_run }} uses: pypa/gh-action-pypi-publish@release/v1 with: repository-url: https://upload.pypi.org/legacy/ pypi-publish-metal: name: Upload Metal release to PyPI runs-on: ubuntu-latest needs: [build_mac_release] permissions: id-token: write environment: name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }} url: 
https://pypi.org/p/mlx-metal steps: - uses: actions/download-artifact@v8 with: name: mlx-metal path: dist - name: Display structure of downloaded files run: du -ah dist - name: Publish package distributions to PyPI if: ${{ !inputs.dry_run }} uses: pypa/gh-action-pypi-publish@release/v1 with: repository-url: https://upload.pypi.org/legacy/ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # tensor files *.safe *.safetensors # Metal libraries *.metallib # Distribution / packaging python/mlx/core python/mlx/share python/mlx/include .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ venv/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST uv.lock .DS_Store # Prerequisites *.d # Compiled Object files *.slo *.lo *.o *.obj *.ilk # Precompiled Headers *.gch *.pch # Compiled Dynamic libraries *.so *.dylib *.dll # Fortran module files *.mod *.smod # Compiled Static libraries *.lai *.la *.a *.lib # Executables *.exe *.out *.app # Debug symbols *.pdb # VSCode .vscode/ # Jetbrains .cache/ # vim *.swp ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 hooks: - id: check-yaml # - id: end-of-file-fixer # - id: trailing-whitespace - repo: https://github.com/pre-commit/mirrors-clang-format rev: v21.1.8 hooks: - id: clang-format # Using this mirror lets us use mypyc-compiled black, which is about 2x faster - repo: https://github.com/psf/black-pre-commit-mirror rev: 26.1.0 hooks: - id: black - repo: https://github.com/pycqa/isort rev: 7.0.0 hooks: - id: isort args: - --profile=black - repo: https://github.com/cheshirekow/cmake-format-precommit rev: v0.6.13 hooks: - id: cmake-format ================================================ FILE: 
ACKNOWLEDGMENTS.md ================================================ # Individual Contributors If you wish to be acknowledged for your contributions, please list your name with a short description of your contribution(s) below. For example: - Jane Smith: Added the `foo` and `bar` ops. MLX was developed with contributions from the following individuals: - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer. - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`. - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``. - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays. - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`. - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions. - Paul Paczuski: Improved stability of BCE loss calculation - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops. - Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function. 
# Organizations MLX has received contributions from the following companies: - NVIDIA Corporation & Affiliates # Third-Party Software MLX leverages several third-party software, listed here together with their license copied verbatim. ## PocketFFT Copyright (C) 2010-2018 Max-Planck-Society All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ## metal-cpp Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. 
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright © 2023 Apple Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: CITATION.cff ================================================ cff-version: 1.2.0 title: mlx message: >- If you use this software, please cite it using the metadata from this file. 
type: software authors: - given-names: Awni family-names: Hannun affiliation: Apple - given-names: Jagrit family-names: Digani affiliation: Apple - given-names: Angelos family-names: Katharopoulos affiliation: Apple - given-names: Ronan family-names: Collobert affiliation: Apple repository-code: 'https://github.com/ml-explore' abstract: >- MLX: efficient and flexible machine learning on Apple silicon license: MIT ================================================ FILE: CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.25) if(NOT MLX_VERSION) file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$") string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}") set(_major ${CMAKE_MATCH_1}) string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}") set(_minor ${CMAKE_MATCH_1}) string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}") set(_patch ${CMAKE_MATCH_1}) set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}") set(MLX_VERSION ${MLX_PROJECT_VERSION}) else() string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION ${MLX_VERSION}) endif() project( mlx LANGUAGES C CXX VERSION ${MLX_PROJECT_VERSION}) # ----------------------------- Setup ----------------------------- set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_INSTALL_MESSAGE NEVER) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # ----------------------------- Configuration ----------------------------- option(MLX_BUILD_TESTS "Build tests for mlx" ON) option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON) option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON) option(MLX_BUILD_CUDA "Build cuda 
backend" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON) option(MLX_BUILD_PYTHON_STUBS "Build stub files for python bindings" ON) option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF) option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF) option(USE_ASAN "Enable AddressSanitizer (ASan)" OFF) option(USE_UBSAN "Enable UndefinedBehaviorSanitizer (UBSan)" OFF) option(USE_TSAN "Enable ThreadSanitizer (TSan)" OFF) # --------------------- Processor tests ------------------------- message( STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}" ) if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") if(NOT MLX_ENABLE_X64_MAC) message( FATAL_ERROR "Building for x86_64 on macOS is not supported." " If you are on an Apple silicon system, check the build" " documentation for possible fixes: " "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source" ) else() set(MLX_BUILD_METAL OFF) message(WARNING "Building for x86_64 arch is not officially supported.") endif() endif() else() set(MLX_BUILD_METAL OFF) endif() if(MLX_USE_CCACHE) find_program(CCACHE_PROGRAM ccache) if(CCACHE_PROGRAM) message(STATUS "Found CCache: ${CCACHE_PROGRAM}") set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") endif() endif() if(USE_ASAN AND USE_TSAN) message( FATAL_ERROR "AddressSanitizer (ASan) and ThreadSanitizer (TSan) are mutually exclusive and cannot be enabled at the same time." 
) endif() set(SANITIZER_COMPILE_FLAGS "") set(SANITIZER_LINK_FLAGS "") if(USE_ASAN) if(WIN32 AND MSVC) list(APPEND SANITIZER_COMPILE_FLAGS /fsanitize=address) list(APPEND SANITIZER_LINK_FLAGS /fsanitize=address) else() list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=address) list(APPEND SANITIZER_LINK_FLAGS -fsanitize=address) if(CMAKE_SYSTEM_NAME STREQUAL "Linux") list(APPEND SANITIZER_LINK_FLAGS -lpthread) endif() endif() endif() if(USE_UBSAN) if(WIN32 AND MSVC) if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined) list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined) else() message( WARNING "UndefinedBehaviorSanitizer (UBSan) is not directly supported via a simple flag in MSVC." ) endif() else() list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined) list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined) endif() endif() if(USE_TSAN) if(WIN32 AND MSVC) message( FATAL_ERROR "ThreadSanitizer (TSan) is not supported by the MSVC compiler. Please use Clang or GCC." 
) elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") message(FATAL_ERROR "ThreadSanitizer (TSan) is not supported on macOS.") else() list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=thread) list(APPEND SANITIZER_LINK_FLAGS -fsanitize=thread) if(CMAKE_SYSTEM_NAME STREQUAL "Linux") list(APPEND SANITIZER_LINK_FLAGS -lpthread) endif() endif() endif() # ----------------------------- Lib ----------------------------- include(FetchContent) # Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24: cmake_policy(SET CMP0135 NEW) add_library(mlx) target_compile_options(mlx PUBLIC ${SANITIZER_COMPILE_FLAGS}) target_link_options(mlx PUBLIC ${SANITIZER_LINK_FLAGS}) if(MLX_BUILD_CUDA) enable_language(CUDA) find_package(CUDAToolkit REQUIRED) find_package(CUDNN REQUIRED) if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL "13.1" AND CUDAToolkit_VERSION VERSION_LESS "13.2") message(FATAL_ERROR "CUDA Toolkit 13.1 is not supported.") endif() endif() if(MLX_BUILD_METAL) find_library(METAL_LIB Metal) find_library(FOUNDATION_LIB Foundation) find_library(QUARTZ_LIB QuartzCore) if(METAL_LIB) message(STATUS "Metal found ${METAL_LIB}") else() message( FATAL_ERROR "Metal not found. 
Set MLX_BUILD_METAL=OFF to build without GPU") endif() if(MLX_METAL_DEBUG) add_compile_definitions(MLX_METAL_DEBUG) endif() # Throw an error if xcrun not found execute_process( COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" OUTPUT_VARIABLE MACOS_SDK_VERSION OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY) if(${MACOS_SDK_VERSION} LESS 14.0) message( FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON") endif() message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}") set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip) if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "") if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0) message(FATAL_ERROR "MLX requires macOS >= 14.0") endif() set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") endif() execute_process( COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'" OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL}) FetchContent_MakeAvailable(metal_cpp) target_include_directories( mlx PUBLIC $ $) target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB}) endif() if(CMAKE_SYSTEM_NAME STREQUAL "Linux") # With newer clang/gcc versions following libs are implicitly linked, but when # building on old distributions they need to be explicitly listed. target_link_libraries(mlx PRIVATE dl pthread) endif() if(WIN32) if(MSVC) # GGUF does not build with MSVC. set(MLX_BUILD_GGUF OFF) endif() # Generate DLL and EXE in the same dir, otherwise EXE will not be able to run. # This is only done when MLX is built as the top project. if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) endif() # Windows implementation of dlfcn.h APIs. 
FetchContent_Declare( dlfcn-win32 GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git GIT_TAG v1.4.2 EXCLUDE_FROM_ALL) block() set(BUILD_SHARED_LIBS OFF) FetchContent_MakeAvailable(dlfcn-win32) endblock() target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src") target_link_libraries(mlx PRIVATE dl) endif() if(MLX_BUILD_CPU) find_library(ACCELERATE_LIBRARY Accelerate) if(ACCELERATE_LIBRARY) message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") set(MLX_BUILD_ACCELERATE ON) else() message(STATUS "Accelerate not found, using default backend.") set(MLX_BUILD_ACCELERATE OFF) endif() if(MLX_BUILD_ACCELERATE) target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY}) add_compile_definitions(MLX_USE_ACCELERATE) add_compile_definitions(ACCELERATE_NEW_LAPACK) elseif(WIN32) # Download and link prebuilt binaries of OpenBLAS. Note that we can only # link with the dynamic library, the prebuilt binaries were built with MinGW # so static-linking would require linking with MinGW's runtime. FetchContent_Declare( openblas URL "https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.31/OpenBLAS-0.3.31-x64.zip" ) FetchContent_MakeAvailable(openblas) target_link_libraries(mlx PRIVATE "${openblas_SOURCE_DIR}/lib/libopenblas.lib") target_include_directories(mlx PRIVATE "${openblas_SOURCE_DIR}/include") # Make sure the DLL file is placed in the same dir with executables. set(OPENBLAS_DLL_FILE "${openblas_SOURCE_DIR}/bin/libopenblas.dll") add_custom_command( TARGET mlx POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${OPENBLAS_DLL_FILE} ${CMAKE_BINARY_DIR}) else() if(${CMAKE_HOST_APPLE}) # The blas shipped in macOS SDK is not supported, search homebrew for # openblas instead. set(BLA_VENDOR OpenBLAS) set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas") endif() # Search and link with lapack. 
find_package(LAPACK REQUIRED) if(NOT LAPACK_FOUND) message(FATAL_ERROR "Must have LAPACK installed") endif() find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include /usr/local/opt/openblas/include) message(STATUS "Lapack lib " ${LAPACK_LIBRARIES}) message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES}) # List blas after lapack otherwise we may accidentally include an old # version of lapack.h from the include dirs of blas. find_package(BLAS REQUIRED) if(NOT BLAS_FOUND) message(FATAL_ERROR "Must have BLAS installed") endif() # TODO find a cleaner way to do this find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include $ENV{BLAS_HOME}/include) message(STATUS "Blas lib " ${BLAS_LIBRARIES}) message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES}) endif() else() set(MLX_BUILD_ACCELERATE OFF) endif() message(STATUS "Downloading json") FetchContent_Declare( json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) FetchContent_MakeAvailable(json) target_include_directories( mlx PRIVATE $) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) target_include_directories( mlx PUBLIC $ $) if(USE_SYSTEM_FMT) find_package(fmt REQUIRED) else() FetchContent_Declare( fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git GIT_TAG 12.1.0 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(fmt) endif() target_link_libraries(mlx PRIVATE $) if(MLX_BUILD_PYTHON_BINDINGS) message(STATUS "Building Python bindings.") find_package( Python 3.10 COMPONENTS Interpreter Development.Module REQUIRED) FetchContent_Declare( nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v2.10.2 GIT_SHALLOW TRUE EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(nanobind) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) endif()
if(MLX_BUILD_TESTS) include(CTest) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests) endif() if(MLX_BUILD_EXAMPLES) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp) endif() if(MLX_BUILD_BENCHMARKS) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp) endif() # ----------------------------- Installation ----------------------------- include(GNUInstallDirs) if(WIN32) # Install DLLs to the same dir with extension file (core.pyd) on Windows. set(CMAKE_INSTALL_BINDIR ".") if(MLX_BUILD_CPU) # Install OpenBLAS. install(FILES ${OPENBLAS_DLL_FILE} TYPE BIN) endif() endif() # Install library install( TARGETS mlx EXPORT MLXTargets LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) # Install headers install( DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} COMPONENT headers FILES_MATCHING PATTERN "*.h" PATTERN "backend/metal/kernels.h" EXCLUDE) # Install metal dependencies if(MLX_BUILD_METAL) # Install metal cpp install( DIRECTORY ${metal_cpp_SOURCE_DIR}/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp COMPONENT metal_cpp_source) endif() # Install cmake config set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake) set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake) set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX) install( EXPORT MLXTargets FILE MLXTargets.cmake DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}) include(CMakePackageConfigHelpers) write_basic_package_version_file( ${MLX_CMAKE_BUILD_VERSION_CONFIG} COMPATIBILITY SameMajorVersion VERSION ${MLX_VERSION}) configure_package_config_file( ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG} INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} NO_CHECK_REQUIRED_COMPONENTS_MACRO PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR) install(FILES ${MLX_CMAKE_BUILD_CONFIG} 
${MLX_CMAKE_BUILD_VERSION_CONFIG} DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}) install(DIRECTORY ${CMAKE_MODULE_PATH}/ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}) ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. ## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action 
in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. 
This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. 
[homepage]: https://www.contributor-covenant.org [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html [Mozilla CoC]: https://github.com/mozilla/diversity [FAQ]: https://www.contributor-covenant.org/faq [translations]: https://www.contributor-covenant.org/translations ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to MLX We want to make contributing to this project as easy and transparent as possible. ## Pull Requests 1. Fork and submit pull requests to the repo. 2. If you've added code that should be tested, add tests. 3. If a change is likely to impact efficiency, run some of the benchmarks before and after the change. Examples of benchmarks can be found in `benchmarks/python/`. 4. If you've changed APIs, update the documentation. 5. Every PR should have passing tests and at least one review. 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. This should install hooks for running `black` and `clang-format` to ensure consistent style for C++ and python code. You can also run the formatters manually as follows: ```shell clang-format -i file.cpp ``` ```shell black file.py ``` or run `pre-commit run --all-files` to check all files in the repo. ## Issues We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. ## License By contributing to MLX, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. ================================================ FILE: LICENSE ================================================ MIT License Copyright © 2023 Apple Inc. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: MANIFEST.in ================================================ include CMakeLists.txt include mlx.pc.in recursive-include mlx/ * include cmake/* include python/src/* include python/mlx/py.typed # support type hinting as in PEP-561 ================================================ FILE: README.md ================================================ # MLX [**Quickstart**](#quickstart) | [**Installation**](#installation) | [**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) | [**Examples**](#examples) [![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx) MLX is an array framework for machine learning on Apple silicon, brought to you by Apple machine learning research. Some key features of MLX include: - **Familiar APIs**: MLX has a Python API that closely follows NumPy. 
MLX also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror the Python API. MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs that closely follow PyTorch to simplify building more complex models. - **Composable function transformations**: MLX supports composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization. - **Lazy computation**: Computations in MLX are lazy. Arrays are only materialized when needed. - **Dynamic graph construction**: Computation graphs in MLX are constructed dynamically. Changing the shapes of function arguments does not trigger slow compilations, and debugging is simple and intuitive. - **Multi-device**: Operations can run on any of the supported devices (currently the CPU and the GPU). - **Unified memory**: A notable difference between MLX and other frameworks is the *unified memory model*. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without transferring data. MLX is designed by machine learning researchers for machine learning researchers. The framework is intended to be user-friendly, but still efficient to train and deploy models. The design of the framework itself is also conceptually simple. We intend to make it easy for researchers to extend and improve MLX with the goal of quickly exploring new ideas. The design of MLX is inspired by frameworks like [NumPy](https://numpy.org/doc/stable/index.html), [PyTorch](https://pytorch.org/), [Jax](https://github.com/google/jax), and [ArrayFire](https://arrayfire.org/). ## Examples The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a variety of examples, including: - [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
- Large-scale text generation with [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora). - Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion). - Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper). ## Quickstart See the [quick start guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html) in the documentation. ## Installation MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on macOS, run: ```bash pip install mlx ``` To install the CUDA backend on Linux, run: ```bash pip install mlx[cuda] ``` To install a CPU-only Linux package, run: ```bash pip install mlx[cpu] ``` Check out the [documentation](https://ml-explore.github.io/mlx/build/html/install.html#) for more information on building the C++ and Python APIs from source. ## Contributing Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information on contributing to MLX. See the [docs](https://ml-explore.github.io/mlx/build/html/install.html) for more information on building from source and running tests. We are grateful for all of [our contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute to MLX and wish to be acknowledged, please add your name to the list in your pull request. ## Citing MLX The MLX software suite was initially developed with equal contribution by Awni Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert.
If you find MLX useful in your research and wish to cite it, please use the following BibTex entry: ```text @software{mlx2023, author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert}, title = {{MLX}: Efficient and flexible machine learning on Apple silicon}, url = {https://github.com/ml-explore}, version = {0.0}, year = {2023}, } ``` ================================================ FILE: benchmarks/cpp/CMakeLists.txt ================================================ function(build_benchmark SRCFILE) get_filename_component(src_name ${SRCFILE} NAME_WE) set(target "${src_name}") add_executable(${target} ${SRCFILE}) target_link_libraries(${target} PRIVATE mlx) endfunction(build_benchmark) build_benchmark(single_ops.cpp) build_benchmark(irregular_strides.cpp) build_benchmark(compare_devices.cpp) build_benchmark(autograd.cpp) ================================================ FILE: benchmarks/cpp/autograd.cpp ================================================ // Copyright © 2023 Apple Inc. #include #include "mlx/mlx.h" #include "time_utils.h" namespace mx = mlx::core; void time_value_and_grad() { auto x = mx::ones({200, 1000}); mx::eval(x); auto fn = [](mx::array x) { for (int i = 0; i < 20; ++i) { x = mx::log(mx::exp(x)); } return mx::sum(x); }; auto grad_fn = mx::grad(fn); auto independent_value_and_grad = [&]() { auto value = fn(x); auto dfdx = grad_fn(x); return std::vector{value, dfdx}; }; TIME(independent_value_and_grad); auto value_and_grad_fn = mx::value_and_grad(fn); auto combined_value_and_grad = [&]() { auto [value, dfdx] = value_and_grad_fn(x); return std::vector{value, dfdx}; }; TIME(combined_value_and_grad); } int main() { std::cout << "Benchmarks for " << mx::default_device() << std::endl; time_value_and_grad(); } ================================================ FILE: benchmarks/cpp/compare_devices.cpp ================================================ // Copyright © 2023 Apple Inc. 
#include #include "mlx/mlx.h" #include "time_utils.h" namespace mx = mlx::core; void time_add_op() { std::vector sizes(1, 1); for (int i = 0; i < 9; ++i) { sizes.push_back(10 * sizes.back()); } set_default_device(mx::Device::cpu); for (auto size : sizes) { auto a = mx::random::uniform({size}); auto b = mx::random::uniform({size}); mx::eval(a, b); std::cout << "Size " << size << std::endl; TIMEM("cpu", mx::add, a, b, mx::Device::cpu); TIMEM("gpu", mx::add, a, b, mx::Device::gpu); } } int main() { time_add_op(); } ================================================ FILE: benchmarks/cpp/irregular_strides.cpp ================================================ // Copyright © 2023 Apple Inc. #include #include #include #include "mlx/mlx.h" #include "time_utils.h" namespace mx = mlx::core; /* Each TIMEM label describes the operand layout being timed ("regular", "strided", "transpose", "broadcast dim(s) ..."), not the MLX call; broadcast dims are 0-indexed. */ void time_irregular_binary_ops_1D() { auto device = mx::default_device(); int size = 1000000; int step = 2; auto a = mx::random::uniform({size}); auto b = mx::random::uniform({size}); mx::eval(a, b); a = slice(a, {0}, {size}, {step}); b = slice(b, {0}, {size}, {step}); TIMEM("1D strided", mx::add, a, b, device); } void time_irregular_binary_ops_2D() { auto device = mx::default_device(); int size = 2048; auto a = mx::random::uniform({size, size}); auto b = mx::random::uniform({size, size}); mx::eval(a, b); TIMEM("2D regular", mx::add, a, b, device); b = mx::transpose(b); mx::eval(b); TIMEM("2D transpose", mx::add, a, b, device); b = mx::random::uniform({size}); mx::eval(b); TIMEM("2D broadcast dim 0", mx::add, a, b, device); b = mx::reshape(b, {size, 1}); mx::eval(b); TIMEM("2D broadcast dim 1", mx::add, a, b, device); } void time_irregular_binary_ops_3D() { auto device = mx::default_device(); int d0 = 32; int d1 = 512; int d2 = 512; auto a = mx::random::uniform({d0, d1, d2}); auto b = mx::random::uniform({d0, d1, d2}); TIMEM("3D regular", mx::add, a, b, device); b = mx::transpose(b, {0, 2, 1}); TIMEM("3D transpose", mx::add, a, b, device); b = mx::random::uniform({d1, d2}); TIMEM("3D
broadcast dim 0", mx::add, a, b, device); b = mx::random::uniform({d0, 1, d2}); TIMEM("3D broadcast dim 1", mx::add, a, b, device); b = mx::random::uniform({d0, d1, 1}); TIMEM("3D broadcast dim 2", mx::add, a, b, device); b = mx::random::uniform({d2}); TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device); b = mx::random::uniform({d1, 1}); TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device); b = mx::random::uniform({d0, 1, 1}); TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device); } void time_irregular_binary_ops_4D() { auto device = mx::default_device(); mx::Shape shape = {8, 8, 512, 512}; auto a = mx::random::uniform(shape); auto b = mx::random::uniform(shape); TIMEM("4D regular", mx::add, a, b, device); b = mx::transpose(b, {0, 1, 3, 2}); TIMEM("4D transpose", mx::add, a, b, device); std::string om = "4D broadcast dims "; for (int i = 0; i < shape.size(); ++i) { shape[i] = 1; b = mx::random::uniform(shape); std::ostringstream msg; msg << om << i; TIMEM(msg.str(), mx::add, a, b, device); for (int j = i + 1; j < shape.size(); ++j) { shape[j] = 1; std::ostringstream msg; msg << om << i << ", " << j; b = mx::random::uniform(shape); TIMEM(msg.str(), mx::add, a, b, device); shape[j] = a.shape(j); for (int k = j + 1; k < shape.size(); ++k) { shape[k] = 1; std::ostringstream msg; msg << om << i << ", " << j << ", " << k; b = mx::random::uniform(shape); TIMEM(msg.str(), mx::add, a, b, device); shape[k] = a.shape(k); } } shape[i] = a.shape(i); } } void time_irregular_reshape() { auto device = mx::default_device(); mx::Shape shape; auto reshape_fn = [&shape, device](const mx::array& a) { return mx::reshape(a, shape, device); }; int size = 64; int d = 2 * size; auto a = mx::random::uniform({d, d, d}); shape = {8 * size, size, size}; TIMEM("3D contiguous", reshape_fn, a); a = mx::transpose(a); shape = {8 * size, size, size}; TIMEM("3D transpose", reshape_fn, a); a = mx::transpose(a, {1, 2, 0}); shape = {8 * size, size, size}; TIMEM("3D transpose dims 1 2",
reshape_fn, a); a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d}); TIMEM("3D broadcast dim 0", reshape_fn, a); a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d}); TIMEM("3D broadcast dim 1", reshape_fn, a); a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d}); TIMEM("3D broadcast dim 2", reshape_fn, a); a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d}); TIMEM("3D broadcast dims 0, 1", reshape_fn, a); a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d}); TIMEM("3D broadcast dims 0, 2", reshape_fn, a); a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d}); TIMEM("3D broadcast dims 1, 2", reshape_fn, a); /* A {1, 1, 1} source is broadcast along all three (0-indexed) dims. */ a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d}); TIMEM("3D broadcast dims 0, 1, 2", reshape_fn, a); } void time_irregular_astype_1D() { auto device = mx::default_device(); int size = 1000000; int step = 2; auto a = mx::random::uniform({size}); a = slice(a, {0}, {size}, {step}); TIMEM("1D strided", mx::astype, a, mx::int32, device); } void time_irregular_astype_2D() { auto device = mx::default_device(); int size = 2048; mx::Shape shape = {size, size}; auto a = mx::random::uniform(shape); TIMEM("2D regular", mx::astype, a, mx::int32, device); a = mx::transpose(a); TIMEM("2D transpose", mx::astype, a, mx::int32, device); a = mx::broadcast_to(mx::random::uniform({size}), shape); TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device); a = mx::broadcast_to(mx::random::uniform({size, 1}), shape); TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device); } int main(int argc, char** argv) { if (argc > 1) { bool use_gpu = !strcmp(argv[1], "gpu"); set_default_device(use_gpu ?
mx::Device::gpu : mx::Device::cpu); } std::cout << "Benchmarks for " << mx::default_device() << std::endl; time_irregular_binary_ops_1D(); time_irregular_binary_ops_2D(); time_irregular_binary_ops_3D(); time_irregular_binary_ops_4D(); time_irregular_reshape(); time_irregular_astype_1D(); time_irregular_astype_2D(); } ================================================ FILE: benchmarks/cpp/single_ops.cpp ================================================ // Copyright © 2023 Apple Inc. #include "mlx/mlx.h" #include "time_utils.h" namespace mx = mlx::core; void time_creation_ops() { int M = 2000; int N = 500; auto shape = {M, N}; auto full_fp32 = [&]() { return mx::full(shape, 3.3f); }; TIME(full_fp32); auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); }; TIME(zeros_fp32); auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); }; TIME(ones_fp32); auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); }; TIME(arange_fp32); } /* Times mx::astype on a 2000x500 array; each label names the source and target dtype. */ void time_type_conversions() { int M = 2000; int N = 500; auto shape = {M, N}; auto device = mx::default_device(); auto a = mx::zeros(shape, mx::float32); mx::eval(a); TIMEM("float32 to int32", mx::astype, a, mx::int32, device); TIMEM("float32 to uint32", mx::astype, a, mx::uint32, device); a = mx::zeros(shape, mx::int32); mx::eval(a); TIMEM("int32 to float32", mx::astype, a, mx::float32, device); a = mx::zeros(shape, mx::bool_); mx::eval(a); TIMEM("bool to float32", mx::astype, a, mx::float32, device); TIMEM("bool to int32", mx::astype, a, mx::int32, device); TIMEM("bool to uint32", mx::astype, a, mx::uint32, device); } void time_random_generation() { int M = 2000; int N = 500; auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); }; TIME(uniform); auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); }; TIME(normal); } void time_unary_ops() { int M = 2000; int N = 500; auto device = mx::default_device(); auto a = mx::random::normal({M, N}); mx::eval(a);
TIME(mlx::core::abs, a, device); TIME(mx::negative, a, device); TIME(mx::sign, a, device); TIME(mx::square, a, device); TIME(mlx::core::sqrt, a, device); TIME(mx::rsqrt, a, device); TIME(mlx::core::exp, a, device); a = mx::random::uniform({M, N}); TIME(mlx::core::log, a, device); } void time_binary_ops() { int M = 1000, N = 100, K = 10; auto condition = mx::random::randint(0, 2, {M, N, K}); auto a = mx::random::uniform({M, N, K}); auto b = mx::random::uniform({M, N, K}); auto device = mx::default_device(); mx::eval(a, b); TIME(mx::add, a, b, device); TIME(mx::subtract, a, b, device); TIME(mx::multiply, a, b, device); TIME(mx::divide, a, b, device); TIME(mx::maximum, a, b, device); TIME(mx::minimum, a, b, device); TIME(mx::where, condition, a, b, device); condition = mx::array({true}); b = mx::random::uniform({1}); mx::eval(b); TIMEM("scalar", mx::add, a, b, device); TIMEM("vector-scalar", mx::subtract, a, b, device); TIMEM("scalar-vector", mx::subtract, b, a, device); TIMEM("scalar", mx::multiply, a, b, device); TIMEM("vector-scalar", mx::divide, a, b, device); TIMEM("scalar-vector", mx::divide, b, a, device); TIMEM("scalar-vector", mx::where, condition, a, b, device); condition = mx::broadcast_to(mx::array({true}), {1000, 100}); a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100}); b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100}); mx::eval(a, b); TIMEM("scalar-scalar broadcast", mx::add, a, b, device); TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device); TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device); TIMEM("scalar-scalar broadcast", mx::divide, a, b, device); TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device); } void time_strided_ops() { int M = 50, N = 50, O = 50, P = 50; auto a = mx::random::uniform({M, N, O, P}); auto b = mx::random::uniform({M, N, O, P}); auto device = mx::default_device(); mx::eval(a, b); TIMEM("non-strided", mx::add, a, b, device); a = mx::transpose(a, {1, 0, 2, 3}); b = 
mx::transpose(b, {3, 2, 0, 1}); mx::eval(a, b); TIMEM("strided", mx::add, a, b, device); } void time_comparisons() { int M = 1000, N = 100, K = 10; auto a = mx::random::uniform({M, N, K}); auto b = mx::random::uniform({M, N, K}); auto device = mx::default_device(); mx::eval(a, b); TIME(mx::equal, a, b, device); TIME(mx::greater, a, b, device); TIME(mx::greater_equal, a, b, device); TIME(mx::less, a, b, device); TIME(mx::less_equal, a, b, device); } void time_matvec() { int M = 2000, N = 200; auto a = mx::random::uniform({M, N}); auto b = mx::random::uniform({N}); auto c = mx::random::uniform({M}); mx::eval(a, b, c); auto matvec = [&]() { return mx::matmul(a, b); }; TIME(matvec); auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); }; TIME(matvec_transpose); } void time_matmul() { int M = 1000, N = 1000, K = 1000; auto a = mx::random::uniform({M, K}); auto b = mx::random::uniform({K, N}); auto device = mx::default_device(); mx::eval(a, b); TIME(mx::matmul, a, b, device); auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); }; TIME(transpose_matmul); } void time_reductions() { auto a = mx::random::normal({10000, 1000}); mx::eval(a); auto sum_all = [&a]() { return mx::sum(a, false); }; TIME(sum_all); auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); }; TIME(sum_along_0); auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); }; TIME(sum_along_1); auto prod_all = [&a]() { return mx::prod(a, false); }; TIME(prod_all); auto all_true = [&a]() { return mx::all(a, false); }; TIME(all_true); auto all_along_0 = [&a]() { return mx::all(a, 0, false); }; TIME(all_along_0); auto all_along_1 = [&a]() { return mx::all(a, 1, false); }; TIME(all_along_1); auto any_true = [&a]() { return mx::any(a, false); }; TIME(any_true); auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); }; TIME(argmin_along_0); auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); }; TIME(argmin_along_1); auto indices = mx::array({1}); auto 
updates = mx::reshape(mx::array({NAN}), {1, 1, 1}); std::vector axes{0}; auto b = scatter(a, {indices}, updates, axes); mx::eval(b); auto max_along_0 = [&b]() { return mx::max(b, 0, false); }; TIME(max_along_0); auto max_along_1 = [&b]() { return mx::max(b, 1, false); }; TIME(max_along_1); auto min_along_0 = [&b]() { return mx::min(b, 0, false); }; TIME(min_along_0); auto min_along_1 = [&b]() { return mx::min(b, 1, false); }; TIME(min_along_1); } void time_gather_scatter() { auto a = mx::random::normal({1000, 768}); mx::eval(a); auto indices = mx::random::randint(0, 1000, {256}); mx::eval(indices); auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); }; TIME(embedding_lookup); indices = mx::random::randint(0, 768 * 1000, {256 * 768}); mx::eval(indices); auto single_element_lookup = [&a, &indices]() { return mx::take(a, indices); }; TIME(single_element_lookup); indices = mx::random::randint(0, 1000, {256}); auto updates = mx::random::normal({256, 1, 768}); mx::eval(indices, updates); auto embedding_update = [&a, &indices, &updates]() { return scatter(a, indices, updates, 0); }; TIME(embedding_update); auto embedding_add = [&a, &indices, &updates]() { return scatter_add(a, indices, updates, 0); }; TIME(embedding_add); a = mx::reshape(a, {-1}); indices = mx::random::randint(0, 768 * 1000, {768 * 256}); updates = mx::random::normal({256 * 768, 1}); mx::eval(a, indices, updates); auto single_element_update = [&a, &indices, &updates]() { return scatter(a, indices, updates, 0); }; TIME(single_element_update); auto single_element_add = [&a, &indices, &updates]() { return scatter_add(a, indices, updates, 0); }; TIME(single_element_add); } void time_divmod() { auto a = mx::random::normal({1000}); auto b = mx::random::normal({1000}); mx::eval({a, b}); auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); }; TIME(divmod_fused); auto divmod_separate = [&a, &b]() { return std::vector{mx::floor_divide(a, b), mx::remainder(a, b)}; }; 
TIME(divmod_separate); } int main() { std::cout << "Benchmarks for " << mx::default_device() << std::endl; time_creation_ops(); time_type_conversions(); time_unary_ops(); time_binary_ops(); time_strided_ops(); time_random_generation(); time_comparisons(); time_matvec(); time_matmul(); time_reductions(); time_gather_scatter(); time_divmod(); } ================================================ FILE: benchmarks/cpp/time_utils.h ================================================ // Copyright © 2023 Apple Inc. #pragma once #include #include #include #include "mlx/mlx.h" #define milliseconds(x) \ (std::chrono::duration_cast(x).count() / 1e6) #define time_now() std::chrono::high_resolution_clock::now() #define TIME(FUNC, ...) \ std::cout << "Timing " << #FUNC << " ... " << std::flush \ << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \ << std::endl; #define TIMEM(MSG, FUNC, ...) \ std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \ << std::flush << std::setprecision(5) \ << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl; /* time_fn: calls eval(fn(args...)) 5 times as warmup, then 100 timed times, and returns the mean time per call in milliseconds. The arguments are passed as plain lvalues on every call: std::forward inside a loop could hand an rvalue to a by-value callee on the first iteration and a moved-from object on every later one. */ template double time_fn(F fn, Args&&... args) { /* warmup */ for (int i = 0; i < 5; ++i) { eval(fn(args...)); } int num_iters = 100; auto start = time_now(); for (int i = 0; i < num_iters; i++) { eval(fn(args...)); } auto end = time_now(); return milliseconds(end - start) / static_cast(num_iters); } ================================================ FILE: benchmarks/numpy/single_ops.py ================================================ # Copyright © 2023 Apple Inc.
import numpy as np from time_utils import time_fn def time_add(): a = np.ones((100, 100, 10), dtype=np.float32) b = np.ones((100, 100, 10), dtype=np.float32) time_fn(np.add, a, b) def time_matmul(): a = np.random.rand(1000, 500).astype(np.float32) b = np.random.rand(500, 1000).astype(np.float32) time_fn(np.matmul, a, b) def time_exp(): a = np.random.randn(1000, 100).astype(np.float32) time_fn(np.exp, a) def time_take(): a = np.random.rand(10000, 500) ids = np.random.randint(0, 10000, (20, 10)) ids = [idx.reshape(-1) for idx in np.split(ids, 20)] def random_take(): return [np.take(a, idx, 0) for idx in ids] time_fn(random_take) if __name__ == "__main__": time_add() time_matmul() time_exp() time_take() ================================================ FILE: benchmarks/numpy/time_utils.py ================================================ # Copyright © 2023 Apple Inc. import time def time_fn(fn, *args): print(f"Timing {fn.__name__} ...", end=" ") # warmup for _ in range(5): fn(*args) num_iters = 100 tic = time.perf_counter() for _ in range(num_iters): x = fn(*args) toc = time.perf_counter() msec = 1e3 * (toc - tic) / num_iters print(f"{msec:.5f} msec") ================================================ FILE: benchmarks/python/batch_matmul_bench.py ================================================ # Copyright © 2023 Apple Inc. 
import argparse import mlx.core as mx from time_utils import time_fn B = 8 T = 1024 D = 512 def time_batch_matmul(): mx.random.seed(3) a = mx.random.uniform(shape=(B, T, D)) b = mx.random.uniform(shape=(D, D)) c = mx.random.uniform(shape=(B, T, D)) mx.eval(a, b, c) time_fn(mx.matmul, a, b) def batch_vjp_first(): return mx.vjp(mx.matmul, [a, b], [c])[1][0] time_fn(batch_vjp_first) def batch_vjp_second(): return mx.vjp(mx.matmul, [a, b], [c])[1][1] time_fn(batch_vjp_second) def time_unbatch_matmul(): mx.random.seed(3) a = mx.random.uniform(shape=(B * T, D)) b = mx.random.uniform(shape=(D, D)) c = mx.random.uniform(shape=(B * T, D)) mx.eval(a, b, c) time_fn(mx.matmul, a, b) def unbatch_vjp_first(): return mx.matmul(c, mx.transpose(b)) time_fn(unbatch_vjp_first) def unbatch_vjp_second(): return mx.matmul(mx.transpose(a), c) time_fn(unbatch_vjp_second) if __name__ == "__main__": parser = argparse.ArgumentParser("MLX benchmarks.") parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") args = parser.parse_args() if args.gpu: mx.set_default_device(mx.gpu) else: mx.set_default_device(mx.cpu) time_batch_matmul() time_unbatch_matmul() ================================================ FILE: benchmarks/python/blas/bench_gemm.py ================================================ # Copyright © 2023 Apple Inc. 
import argparse import math import os import subprocess import time import mlx.core as mx import numpy as np import torch device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) device_name = device_name.decode("utf-8").strip("\n") N_warmup = 8 N_iter_bench = 80 N_iter_func = 5 def bench(f, a, b): for i in range(N_warmup): f(a, b) torch.mps.synchronize() s = time.perf_counter_ns() for i in range(N_iter_bench): f(a, b) e = time.perf_counter_ns() return (e - s) * 1e-9 def gemm_nn_mlx(a, b): ys = [] for i in range(N_iter_func): y = a @ b ys.append(y) mx.eval(ys) return ys def gemm_nt_mlx(a, b): ys = [] for i in range(N_iter_func): y = a @ b.transpose((0, 2, 1)) ys.append(y) mx.eval(ys) return ys def gemm_tn_mlx(a, b): ys = [] for i in range(N_iter_func): y = a.transpose((0, 2, 1)) @ b ys.append(y) mx.eval(ys) return ys def gemm_tt_mlx(a, b): ys = [] for i in range(N_iter_func): y = a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1)) ys.append(y) mx.eval(ys) return ys @torch.no_grad() def gemm_nn_torch(a, b): ys = [] for i in range(N_iter_func): y = a @ b ys.append(y) torch.mps.synchronize() return ys @torch.no_grad() def gemm_nt_torch(a, b): ys = [] for i in range(N_iter_func): y = a @ b.transpose(-1, -2) ys.append(y) torch.mps.synchronize() return ys @torch.no_grad() def gemm_tn_torch(a, b): ys = [] for i in range(N_iter_func): y = a.transpose(-1, -2) @ b ys.append(y) torch.mps.synchronize() return ys @torch.no_grad() def gemm_tt_torch(a, b): ys = [] for i in range(N_iter_func): y = a.transpose(-1, -2) @ b.transpose(-1, -2) ys.append(y) torch.mps.synchronize() return ys def bench_shape(B, M, N, K, np_dtype, transpose="nn"): shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M) shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K) a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype) b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype) a_mx = mx.array(a_np) b_mx = mx.array(b_np) a_pt = 
torch.from_numpy(a_np).to("mps") b_pt = torch.from_numpy(b_np).to("mps") torch.mps.synchronize() f_mx = { "nn": gemm_nn_mlx, "nt": gemm_nt_mlx, "tn": gemm_tn_mlx, "tt": gemm_tt_mlx, }[transpose] f_pt = { "nn": gemm_nn_torch, "nt": gemm_nt_torch, "tn": gemm_tn_torch, "tt": gemm_tt_torch, }[transpose] time_torch = bench(f_pt, a_pt, b_pt) time_mlx = bench(f_mx, a_mx, b_mx) t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1) t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1) c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b) c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype) atol = 1e-5 if np_dtype == np.float32 else 1e-4 if not np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol): print( f"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}" ) return time_mlx, time_torch def get_gflop_count(B, M, N, K): return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run gemm benchmarks") dtypes = ("float32", "float16", "complex64") transposes = ("nn", "nt", "tn") shapes = ( (16, 234, 768, 3072), (1, 64, 64, 25344), (16, 1024, 1024, 1024), (1, 1024, 1024, 2048), (4, 1024, 1024, 4096), (4, 1024, 4096, 1024), (1, 4096, 4096, 4096), ) for dtype in dtypes: for transpose in transposes: for B, M, N, K in shapes: np_dtype = getattr(np, dtype) time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose) gflop_count = get_gflop_count(B, M, N, K) gflops_mx = gflop_count / (time_mlx) gflops_pt = gflop_count / (time_torch) diff = gflops_mx / gflops_pt - 1.0 print( f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%" ) if gflops_pt >= 2.0 * gflops_mx: print("ATTENTION ^^^^^^^") ================================================ FILE: benchmarks/python/blas/bench_gemv.py ================================================ # Copyright © 2023 
Apple Inc. import os import subprocess import time import matplotlib.pyplot as plt import mlx.core as mx import numpy as np import torch results_dir = "./results" if not os.path.isdir(results_dir): os.mkdir(results_dir) device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) device_name = device_name.decode("utf-8").strip("\n") N_warmup = 5 N_iter_bench = 50 N_iter_func = 20 out_vec_sizes = [128, 512, 2048, 4096] in_vec_sizes = [128, 512, 2048, 4096] benchmark_vector_lens = [] benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2] benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2] benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2] benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000] benchmark_vector_lens.sort() def bench(f, m, v): for i in range(N_warmup): f(m, v) torch.mps.synchronize() s = time.perf_counter_ns() for i in range(N_iter_bench): f(m, v) e = time.perf_counter_ns() return (e - s) * 1e-9 def gemv_mlx(m, v): ys = [] for i in range(N_iter_func): y = m @ v ys.append(y) mx.eval(ys) return ys def gemv_t_mlx(m, v): ys = [] for i in range(N_iter_func): y = v @ m ys.append(y) mx.eval(ys) return ys @torch.no_grad() def gemv_torch(m, v): ys = [] for i in range(N_iter_func): y = m @ v ys.append(y) torch.mps.synchronize() return ys @torch.no_grad() def gemv_t_torch(m, v): ys = [] for i in range(N_iter_func): y = v @ m ys.append(y) torch.mps.synchronize() return ys def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False): shape_mat = (in_vec_len, out_vec_len) if transpose else (out_vec_len, in_vec_len) shape_vec = (1, in_vec_len) if transpose else (in_vec_len, 1) mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype) vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype) mat_mlx = mx.array(mat_npy) vec_mlx = mx.array(vec_npy) mat_trc = torch.from_numpy(mat_npy).to("mps") vec_trc = torch.from_numpy(vec_npy).to("mps") 
torch.mps.synchronize() time_torch = ( bench(gemv_t_torch, mat_trc, vec_trc) if transpose else bench(gemv_torch, mat_trc, vec_trc) ) time_mlx = ( bench(gemv_t_mlx, mat_mlx, vec_mlx) if transpose else bench(gemv_mlx, mat_mlx, vec_mlx) ) c_mlx = ( np.asarray(vec_mlx @ mat_mlx) if transpose else np.asarray(mat_mlx @ vec_mlx) ) c_npy = (vec_npy @ mat_npy) if transpose else (mat_npy @ vec_npy) if not np.allclose(c_mlx, c_npy, atol=2e-5): print( f"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}" ) return time_mlx, time_torch def get_gflop_count(in_vec_len, out_vec_len): return float(2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len) / float( 1024**3 ) def get_gbyte_size(in_vec_len, out_vec_len, np_dtype): n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len item_size = 4 if np_dtype == np.float32 else 2 return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3) def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose): np_dtype = getattr(np, dtype) mlx_gb_s = [] mlx_gflops = [] pyt_gb_s = [] pyt_gflops = [] for out_vec_len in out_vector_lens: gflop_count = get_gflop_count(in_vec_len, out_vec_len) gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype) time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose) mlx_gb_s.append(gbyte_size / time_mlx) pyt_gb_s.append(gbyte_size / time_torch) mlx_gflops.append(gflop_count / time_mlx) pyt_gflops.append(gflop_count / time_torch) if transpose: title = f"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}" else: title = f"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1] ) | {dtype}" ax.plot(out_vector_lens, mlx_gb_s, "tab:blue", label="MLX") ax.plot(out_vector_lens, pyt_gb_s, "tab:red", label="Torch") ax.set_title(title) ax.set(xlabel="out_vector_len", ylabel="Performance (GB/s)") ax.legend() def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose): np_dtype = 
getattr(np, dtype) mlx_gb_s = [] mlx_gflops = [] pyt_gb_s = [] pyt_gflops = [] for in_vec_len in in_vector_lens: gflop_count = get_gflop_count(in_vec_len, out_vec_len) gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype) time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose) mlx_gb_s.append(gbyte_size / time_mlx) pyt_gb_s.append(gbyte_size / time_torch) mlx_gflops.append(gflop_count / time_mlx) pyt_gflops.append(gflop_count / time_torch) if transpose: title = f"([1, in_vec_len] [in_vec_len, {out_vec_len}])" else: title = f"([{out_vec_len}, in_vec_len] X [in_vec_len, 1] )" ax.plot(in_vector_lens, mlx_gb_s, "tab:blue", label="MLX") ax.plot(in_vector_lens, pyt_gb_s, "tab:red", label="Torch") ax.set_title(title) ax.set(xlabel="in_vector_len", ylabel="Performance (GB/s)") ax.legend() for transpose in (False, True): for dtype in ("float32", "float16", "complex64"): fig, axs = plt.subplots( len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained" ) for i, in_vec_len in enumerate(in_vec_sizes): bench_with_in_len( axs[i][0], in_vec_len, benchmark_vector_lens, dtype, transpose ) for i, out_vec_len in enumerate(out_vec_sizes): bench_with_out_len( axs[i][1], out_vec_len, benchmark_vector_lens, dtype, transpose ) op_name = "gemv_t" if transpose else "gemv" fig.suptitle(f"{device_name}: {dtype} {op_name}") fig.savefig( os.path.join( results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf" ) ) plt.close(fig) ================================================ FILE: benchmarks/python/comparative/README.md ================================================ Microbenchmarks comparing MLX to PyTorch ======================================== Implement the same microbenchmarks in MLX and PyTorch to compare and make a list of the biggest possible performance improvements and/or regressions. 
Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for instance to measure the times it takes to sum across the 3rd axis of the above tensor on the cpu. `compare.py` runs several benchmarks and compares the speed-up or lack thereof in comparison to PyTorch. Each bench script can be run with `--print-pid` to print the PID and wait for a key in order to ease attaching a debugger. ================================================ FILE: benchmarks/python/comparative/bench_mlx.py ================================================ # Copyright © 2023 Apple Inc. import argparse import math import os import time from functools import partial import mlx.core as mx import mlx.nn as nn def int_or_list(x): try: return int(x) except ValueError: return [int(xi) for xi in x.split(",")] def none_or_list(x): if x == "": return None else: return [int(xi) for xi in x.split(",")] def dtype_from_str(x): if x == "": return mx.float32 else: dt = getattr(mx, x) if not isinstance(dt, mx.Dtype): raise ValueError(f"{x} is not an mlx dtype") return dt def bench(f, *args): for i in range(10): f(*args) s = time.perf_counter() for i in range(100): f(*args) e = time.perf_counter() return e - s def matmul_square(x): y = x for i in range(10): y = y @ x mx.eval(y) return y def matmul(x, y): ys = [] for i in range(10): ys.append(x @ y) mx.eval(ys) def _quant_matmul(x, w, s, b, transpose, group_size, bits): ys = [] for i in range(10): ys.append( mx.quantized_matmul( x, w, s, b, transpose=transpose, group_size=group_size, bits=bits ) ) mx.eval(ys) quant_matmul = { "quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2), "quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4), "quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8), "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2), "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4), 
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8), "quant_matmul_128_2": partial( _quant_matmul, transpose=False, group_size=128, bits=2 ), "quant_matmul_128_4": partial( _quant_matmul, transpose=False, group_size=128, bits=4 ), "quant_matmul_128_8": partial( _quant_matmul, transpose=False, group_size=128, bits=8 ), "quant_matmul_t_32_2": partial( _quant_matmul, transpose=True, group_size=32, bits=2 ), "quant_matmul_t_32_4": partial( _quant_matmul, transpose=True, group_size=32, bits=4 ), "quant_matmul_t_32_8": partial( _quant_matmul, transpose=True, group_size=32, bits=8 ), "quant_matmul_t_64_2": partial( _quant_matmul, transpose=True, group_size=64, bits=2 ), "quant_matmul_t_64_4": partial( _quant_matmul, transpose=True, group_size=64, bits=4 ), "quant_matmul_t_64_8": partial( _quant_matmul, transpose=True, group_size=64, bits=8 ), "quant_matmul_t_128_2": partial( _quant_matmul, transpose=True, group_size=128, bits=2 ), "quant_matmul_t_128_4": partial( _quant_matmul, transpose=True, group_size=128, bits=4 ), "quant_matmul_t_128_8": partial( _quant_matmul, transpose=True, group_size=128, bits=8 ), } def conv1d(x, y): ys = [] for i in range(10): ys.append(mx.conv1d(x, y)) mx.eval(ys) def conv2d(x, y): ys = [] for i in range(10): ys.append(mx.conv2d(x, y)) mx.eval(ys) def binary(op, x, y): for i in range(100): y = getattr(mx, op)(x, y) mx.eval(y) def reduction(op, axis, x): ys = [] for i in range(100): ys.append(getattr(mx, op)(x, axis=axis)) mx.eval(ys) def sum_and_add(axis, x, y): z = x.sum(axis=axis, keepdims=True) for i in range(50): z = (z + y).sum(axis=axis, keepdims=True) mx.eval(z) def softmax(axis, x): ys = [] for i in range(100): ex = mx.exp(x - mx.max(x, axis=axis, keepdims=True)) y = ex / mx.sum(ex, axis=axis, keepdims=True) ys.append(y) mx.eval(ys) def softmax_fused(axis, x): ys = [] for i in range(100): y = mx.softmax(x, axis=axis) ys.append(y) mx.eval(ys) def relu(x): y = x for i in range(100): y = nn.relu(y) 
mx.eval(y) def leaky_relu(x: mx.array): y = x for i in range(100): y = nn.leaky_relu(y) mx.eval(y) def prelu(x: mx.array): y = x for i in range(100): y = nn.prelu(y, mx.ones(1)) mx.eval(y) def softplus(x: mx.array): y = x for i in range(100): y = nn.softplus(y) mx.eval(y) def mish(x: mx.array): y = x for i in range(100): y = nn.mish(y) mx.eval(y) def leaky_relu(x): y = x for i in range(100): y = nn.leaky_relu(y) mx.eval(y) def elu(x): y = x for i in range(100): y = nn.elu(y) mx.eval(y) def relu6(x): y = x for i in range(100): y = nn.relu6(y) mx.eval(y) def softplus(x): y = x for i in range(100): y = nn.softplus(y) mx.eval(y) def celu(x): y = x for i in range(100): y = nn.celu(y) mx.eval(y) def log_sigmoid(x): y = x for i in range(100): y = nn.log_sigmoid(y) mx.eval(y) def scalar_mult(x): y = x for i in range(100): y = y * (1.0 / (1 + i)) mx.eval(y) def cross_entropy(targets, x): ys = [] for i in range(100): y = mx.logsumexp(x, axis=-1, keepdims=True) - mx.take_along_axis( x, mx.reshape(targets, (-1, 1)), axis=-1 ) ys.append(mx.mean(y)) mx.eval(ys) def logsumexp(axis, x): ys = [] for i in range(100): ys.append(mx.logsumexp(x, axis=axis)) mx.eval(ys) def linear(w, b, x): ys = [] for i in range(10): ys.append(x @ mx.transpose(w, (1, 0)) + b) mx.eval(ys) def linear_fused(w, b, x): ys = [] for i in range(10): ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0)))) mx.eval(ys) def rope(x): *_, N, D = x.shape ys = [] for i in range(10): shape = x.shape x = mx.reshape(x, (-1, N, D)) positions = mx.arange(N) freqs = mx.exp(mx.arange(0.0, D // 2) / math.log(10000 / (D // 2 - 1))) theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) costheta = mx.cos(theta) sintheta = mx.sin(theta) x1 = x[..., ::2] x2 = x[..., 1::2] rx1 = x1 * costheta - x2 * sintheta rx2 = x1 * sintheta + x2 * costheta y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1) y = mx.reshape(y, (-1, N, D)) ys.append(y) mx.eval(ys) def concatenate(axis, x, y): ys = [] for i in range(10): 
ys.append(mx.concatenate([x, y], axis=axis)) mx.eval(ys) def cumsum(axis, x): ys = [] for i in range(10): ys.append(mx.cumsum(x, axis)) mx.eval(ys) def sort(axis, x): ys = [] for i in range(10): ys.append(mx.sort(x, axis)) mx.eval(ys) def topk(axis, x): k = x.shape[axis] // 3 ys = [] for i in range(10): ys.append(mx.topk(x, k, axis)) mx.eval(ys) def step_function(x): y = x for i in range(100): y = nn.step(x) mx.eval(y) def selu(x): y = x for i in range(100): y = nn.selu(x) mx.eval(y) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("benchmark", help="Choose the benchmark to run") parser.add_argument( "--size", default=[(1024, 1024)], type=lambda x: list(map(int, x.split("x"))), help="Set the matrix size", action="append", ) parser.add_argument( "--axis", default=[1], type=int_or_list, help="Set a reduction axis", action="append", ) parser.add_argument( "--transpose", type=none_or_list, default=[], help="Permute the matrix", action="append", ) parser.add_argument( "--print-pid", action="store_true", help="Print the PID and pause" ) parser.add_argument("--cpu", action="store_true", help="Use the CPU") parser.add_argument( "--fused", action="store_true", help="Use fused functions where possible" ) parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append") args = parser.parse_args() if len(args.size) > 1: args.size.pop(0) if len(args.axis) > 1: args.axis.pop(0) if args.cpu: mx.set_default_device(mx.cpu) else: mx.set_default_device(mx.gpu) types = args.dtype if not types: types = [mx.float32] if len(types) < len(args.size): types = types + [types[0]] * (len(args.size) - len(types)) xs = [] for size, dtype in zip(args.size, types): xs.append(mx.random.normal(size).astype(dtype)) for i, t in enumerate(args.transpose): if t is None: continue xs[i] = mx.transpose(xs[i], t) mx.eval(xs) x = xs[0] axis = args.axis[0] if args.print_pid: print(os.getpid()) input("Press enter to run") if args.benchmark == "matmul_square": 
print(bench(matmul_square, x)) elif args.benchmark == "matmul": print(bench(matmul, *xs)) elif args.benchmark.startswith("quant_matmul"): print(bench(quant_matmul[args.benchmark], *xs)) elif args.benchmark == "linear": if args.fused: print(bench(linear_fused, *xs)) else: print(bench(linear, *xs)) elif args.benchmark == "sum_axis": print(bench(reduction, "sum", axis, x)) elif args.benchmark == "sum_all": print(bench(reduction, "sum", None, x)) elif args.benchmark == "argmax": print(bench(reduction, "argmax", axis, x)) elif args.benchmark == "add": print(bench(binary, "add", *xs)) elif args.benchmark == "mul": print(bench(binary, "multiply", *xs)) elif args.benchmark == "softmax": if args.fused: print(bench(softmax_fused, axis, x)) else: print(bench(softmax, axis, x)) elif args.benchmark == "relu": print(bench(relu, x)) elif args.benchmark == "elu": print(bench(elu, x)) elif args.benchmark == "relu6": print(bench(relu6, x)) elif args.benchmark == "celu": print(bench(celu, x)) elif args.benchmark == "log_sigmoid": print(bench(log_sigmoid, x)) elif args.benchmark == "leaky_relu": print(bench(leaky_relu, x)) elif args.benchmark == "prelu": print(bench(prelu, x)) elif args.benchmark == "softplus": print(bench(softplus, x)) elif args.benchmark == "mish": print(bench(mish, x)) elif args.benchmark == "scalar_mul": print(bench(scalar_mult, x)) elif args.benchmark == "cross_entropy": if len(size) != 2: raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size") targets = mx.zeros((len(x),), dtype=mx.uint32) print(bench(cross_entropy, targets, x)) elif args.benchmark == "logsumexp": print(bench(logsumexp, axis, x)) elif args.benchmark == "rope": print(bench(rope, x)) elif args.benchmark == "concatenate": print(bench(concatenate, axis, *xs)) elif args.benchmark == "cumsum": print(bench(cumsum, axis, *xs)) elif args.benchmark == "conv1d": print(bench(conv1d, *xs)) elif args.benchmark == "conv2d": print(bench(conv2d, *xs)) elif args.benchmark == "sort": 
print(bench(sort, axis, x)) elif args.benchmark == "topk": print(bench(topk, axis, x)) elif args.benchmark == "step": print(bench(step_function, x)) elif args.benchmark == "selu": print(bench(selu, x)) elif args.benchmark == "sum_and_add": print(bench(sum_and_add, axis, *xs)) else: raise ValueError("Unknown benchmark") ================================================ FILE: benchmarks/python/comparative/bench_torch.py ================================================ # Copyright © 2023 Apple Inc. import argparse import os import time import torch import torch.cuda import torch.mps def int_or_list(x): try: return int(x) except ValueError: return [int(xi) for xi in x.split(",")] def none_or_list(x): if x == "": return None else: return [int(xi) for xi in x.split(",")] def dtype_from_str(x): if x == "": return torch.float32 else: dt = getattr(torch, x) if not isinstance(dt, torch.dtype): raise ValueError(f"{x} is not a torch dtype") return dt def bench(f, *args): for i in range(10): f(*args) s = time.perf_counter() for i in range(100): f(*args) e = time.perf_counter() return e - s def sync_if_needed(x): if x.device == torch.device("mps"): torch.mps.synchronize() elif x.device == torch.device("cuda"): torch.cuda.synchronize() @torch.no_grad() def matmul_square(x): y = x for i in range(10): y = y @ x sync_if_needed(x) @torch.no_grad() def matmul(x, y): ys = [] for i in range(10): ys.append(x @ y) sync_if_needed(x) @torch.no_grad() def conv1d(x, y): x = torch.transpose(x, -1, -2) y = torch.transpose(y, -1, -2) ys = [] for i in range(10): ys.append(torch.nn.functional.conv1d(x, y)) sync_if_needed(x) @torch.no_grad() def conv2d(x, y): x = torch.permute(x, (0, 3, 1, 2)) y = torch.permute(y, (0, 3, 1, 2)) ys = [] for i in range(10): ys.append(torch.nn.functional.conv2d(x, y)) sync_if_needed(x) @torch.no_grad() def binary(op, x, y): for i in range(100): y = getattr(torch, op)(x, y) sync_if_needed(x) @torch.no_grad() def reduction(op, axis, x): ys = [] for i in range(100): 
ys.append(getattr(x, op)(axis)) sync_if_needed(x) @torch.no_grad() def sum_and_add(axis, x, y): z = x.sum(axis=axis, keepdims=True) for i in range(50): z = (z + y).sum(axis=axis, keepdims=True) sync_if_needed(x) @torch.no_grad() def softmax(axis, x): ys = [] for i in range(100): ex = torch.exp(x - torch.max(x, dim=axis, keepdims=True).values) y = ex / torch.sum(ex, dim=axis, keepdims=True) ys.append(y) sync_if_needed(x) @torch.no_grad() def softmax_fused(axis, x): ys = [] for i in range(100): ys.append(torch.nn.functional.softmax(x, dim=axis)) sync_if_needed(x) @torch.no_grad() def relu(x): y = x for i in range(100): y = torch.nn.functional.relu(y) sync_if_needed(x) @torch.no_grad() def leaky_relu(x): y = x for i in range(100): y = torch.nn.functional.leaky_relu(y) sync_if_needed(x) @torch.no_grad() def elu(x): y = x for i in range(100): y = torch.nn.functional.elu(y) sync_if_needed(x) @torch.no_grad() def celu(x): y = x for i in range(100): y = torch.nn.functional.celu(y) sync_if_needed(x) @torch.no_grad() def relu6(x): y = x for i in range(100): y = torch.nn.functional.relu6(y) sync_if_needed(x) @torch.no_grad() def softplus(x): y = x for i in range(100): y = torch.nn.functional.softplus(y) sync_if_needed(x) @torch.no_grad() def log_sigmoid(x): y = x for i in range(100): y = torch.nn.functional.logsigmoid(y) sync_if_needed(x) @torch.no_grad() def prelu(x: torch.Tensor) -> torch.Tensor: y = x for _ in range(100): y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device)) sync_if_needed(x) @torch.no_grad() def mish(x: torch.Tensor) -> torch.Tensor: y = x for _ in range(100): y = torch.nn.functional.mish(y) sync_if_needed(x) @torch.no_grad() def scalar_mult(x): y = x for i in range(100): y = y * (1.0 / (1 + i)) sync_if_needed(x) @torch.no_grad() def cross_entropy(targets, x): ys = [] for i in range(100): ys.append(torch.nn.functional.cross_entropy(x, targets)) sync_if_needed(x) @torch.no_grad() def logsumexp(axis, x): ys = [] for i in range(100): 
ys.append(torch.logsumexp(x, dim=axis)) sync_if_needed(x) @torch.no_grad() def linear_fused(w, b, x): ys = [] for i in range(10): ys.append(torch.nn.functional.linear(x, w, b)) sync_if_needed(x) @torch.no_grad() def linear(w, b, x): ys = [] for i in range(10): ys.append((x @ torch.transpose(w, -2, -1)) + b) sync_if_needed(x) @torch.no_grad() def rope(x): *_, N, D = x.shape ys = [] for i in range(10): x = x.view(-1, N, D) positions = torch.arange(N, device=x.device) freqs = 10000 ** torch.linspace(0, 1, D // 2, device=x.device) theta = positions[:, None] * freqs[None] costheta = torch.cos(theta) sintheta = torch.sin(theta) x1 = x[..., ::2] x2 = x[..., 1::2] rx1 = x1 * costheta - x2 * sintheta rx2 = x1 * sintheta + x2 * costheta y = torch.cat([rx1[..., None], rx2[..., None]], dim=-1) y = y.reshape(-1, N, D) ys.append(y) sync_if_needed(x) @torch.no_grad() def concatenate(axis, x, y): ys = [] for i in range(10): ys.append(torch.cat([x, y], dim=axis)) sync_if_needed(x) @torch.no_grad() def cumsum(axis, x): ys = [] for i in range(10): ys.append(x.cumsum(axis)) sync_if_needed(x) @torch.no_grad() def sort(axis, x): ys = [] for i in range(10): ys.append(torch.sort(x, dim=axis)[0]) sync_if_needed(x) @torch.no_grad() def topk(axis, x): k = x.shape[axis] // 3 ys = [] for i in range(10): ys.append(torch.topk(x, k, dim=axis)[0]) sync_if_needed(x) @torch.no_grad() def step_function(x): y = x for i in range(100): y = torch.where(y < 0, 0, 1) sync_if_needed(x) @torch.no_grad() def selu(x): y = x for i in range(100): y = torch.nn.functional.selu(y) sync_if_needed(x) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("benchmark", help="Choose the benchmark to run") parser.add_argument( "--size", default=[(1024, 1024)], type=lambda x: list(map(int, x.split("x"))), help="Set the matrix size", action="append", ) parser.add_argument( "--axis", default=[1], type=int_or_list, help="Set a reduction axis", action="append", ) parser.add_argument( "--transpose", 
type=none_or_list, default=[], help="Permute the matrix", action="append", ) parser.add_argument( "--print-pid", action="store_true", help="Print the PID and pause" ) parser.add_argument("--cpu", action="store_true", help="Use the CPU") parser.add_argument( "--fused", action="store_true", help="Use fused functions where possible" ) parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append") args = parser.parse_args() if len(args.size) > 1: args.size.pop(0) if len(args.axis) > 1: args.axis.pop(0) torch.set_num_threads(1) device = "mps" if torch.cuda.is_available(): device = "cuda" if args.cpu: device = "cpu" types = args.dtype if not types: types = [torch.float32] if len(types) < len(args.size): types = types + [types[0]] * (len(args.size) - len(types)) xs = [] for size, dtype in zip(args.size, types): xs.append(torch.randn(*size).to(device).to(dtype)) for i, t in enumerate(args.transpose): if t is None: continue xs[i] = xs[i].permute(*t) x = xs[0] axis = args.axis[0] if args.print_pid: print(os.getpid()) input("Press enter to run") if args.benchmark == "matmul_square": print(bench(matmul_square, x)) elif args.benchmark == "matmul": print(bench(matmul, *xs)) elif args.benchmark == "linear": if args.fused: print(bench(linear_fused, *xs)) else: print(bench(linear, *xs)) elif args.benchmark == "sum_axis": print(bench(reduction, "sum", axis, x)) elif args.benchmark == "sum_all": print(bench(reduction, "sum", None, x)) elif args.benchmark == "argmax": print(bench(reduction, "argmax", axis, x)) elif args.benchmark == "add": print(bench(binary, "add", *xs)) elif args.benchmark == "mul": print(bench(binary, "mul", *xs)) elif args.benchmark == "softmax": if args.fused: print(bench(softmax_fused, axis, x)) else: print(bench(softmax, axis, x)) elif args.benchmark == "relu": print(bench(relu, x)) elif args.benchmark == "leaky_relu": print(bench(leaky_relu, x)) elif args.benchmark == "elu": print(bench(elu, x)) elif args.benchmark == "relu6": 
print(bench(relu6, x)) elif args.benchmark == "softplus": print(bench(softplus, x)) elif args.benchmark == "celu": print(bench(celu, x)) elif args.benchmark == "log_sigmoid": print(bench(log_sigmoid, x)) elif args.benchmark == "prelu": print(bench(prelu, x)) elif args.benchmark == "mish": print(bench(mish, x)) elif args.benchmark == "scalar_mul": print(bench(scalar_mult, x)) elif args.benchmark == "cross_entropy": if len(size) != 2: raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size") targets = torch.zeros(len(x), dtype=torch.long).to(x.device) print(bench(cross_entropy, targets, x)) elif args.benchmark == "logsumexp": print(bench(logsumexp, axis, x)) elif args.benchmark == "rope": print(bench(rope, x)) elif args.benchmark == "concatenate": print(bench(concatenate, axis, *xs)) elif args.benchmark == "cumsum": print(bench(cumsum, axis, *xs)) elif args.benchmark == "conv1d": print(bench(conv1d, *xs)) elif args.benchmark == "conv2d": print(bench(conv2d, *xs)) elif args.benchmark == "sort": print(bench(sort, axis, x)) elif args.benchmark == "topk": print(bench(topk, axis, x)) elif args.benchmark == "step": print(bench(step_function, x)) elif args.benchmark == "selu": print(bench(selu, x)) elif args.benchmark == "sum_and_add": print(bench(sum_and_add, axis, *xs)) else: raise ValueError(f"Unknown benchmark `{args.benchmark}`.") ================================================ FILE: benchmarks/python/comparative/compare.py ================================================ # Copyright © 2023 Apple Inc. 
#!/usr/bin/env python import argparse import re from pathlib import Path from subprocess import run BENCH_MLX = Path(__file__).parent / "bench_mlx.py" BENCH_TORCH = Path(__file__).parent / "bench_torch.py" def run_or_raise(*args, **kwargs): try: result = run(*args, capture_output=True, **kwargs) return float(result.stdout) except ValueError: raise ValueError( f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}" ) def compare(args): t_mlx = run_or_raise(["python", BENCH_MLX] + args) t_torch = run_or_raise(["python", BENCH_TORCH] + args) print((t_torch - t_mlx) / t_torch, " ".join(args), sep="\t") def compare_mlx_dtypes(args, dt1, dt2): t_mlx_dt1 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt1]) t_mlx_dt2 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt2]) print((t_mlx_dt2 - t_mlx_dt1) / t_mlx_dt2, " ".join(args), sep="\t") def make_regex_search(regexes): compiled_regexes = list(map(re.compile, regexes)) def search(x): return (c.search(x) is not None for c in compiled_regexes) return search def make_predicate(positive_filter, negative_filter): if positive_filter is not None: positive_filter_search = make_regex_search(positive_filter) positive_filter = lambda x: all(positive_filter_search(x)) else: positive_filter = lambda x: True if negative_filter is not None: negative_filter_search = make_regex_search(negative_filter) negative_filter = lambda x: not any(negative_filter_search(x)) else: negative_filter = lambda x: True def predicate(x): return positive_filter(x) and negative_filter(x) return predicate if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run comparisons against PyTorch") parser.add_argument( "--filter", "-f", help="Regex filter to select benchmarks", nargs="+" ) parser.add_argument( "--negative_filter", "-n", help="Regex filter to remove benchmarks", nargs="+" ) parser.add_argument( "--mlx_dtypes", "-d", help="Compare mlx benchmarks between the 2 provided data types", nargs=2, ) args, 
rest = parser.parse_known_args() _filter = make_predicate(args.filter, args.negative_filter) if args.mlx_dtypes: compare_filtered = lambda x: ( compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]) if _filter(x) else None ) else: compare_filtered = lambda x: compare(x.split() + rest) if _filter(x) else None # Binary ops compare_filtered("add --size 10x1024x128 --size 1x1024x128 --cpu") compare_filtered("add --size 10x1024x128 --size 1x1024x128") compare_filtered("add --size 1024x128 --size 1x128 --cpu") compare_filtered("add --size 1024x128 --size 1x128") compare_filtered("add --size 1024x4096 --size 1x4096 --cpu") compare_filtered("add --size 1024x4096 --size 1x4096") compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0 --cpu") compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0") compare_filtered("add --size 1024x1024 --size 1024x1024 --cpu") compare_filtered("add --size 1024x1024 --size 1024x1024") compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0 --cpu") compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0") compare_filtered( "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0 --cpu" ) compare_filtered( "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0" ) # Reduction ops compare_filtered("sum_all --size 10x1024x128 --cpu") compare_filtered("sum_all --size 10x1024x128") compare_filtered("sum_axis --size 16x1024x128 --axis 2 --cpu") compare_filtered("sum_axis --size 16x1024x128 --axis 2") compare_filtered("sum_axis --size 16x128x1024 --axis 2 --cpu") compare_filtered("sum_axis --size 16x128x1024 --axis 2") compare_filtered("sum_axis --size 1024x1024 --axis 1 --cpu") compare_filtered("sum_axis --size 1024x1024 --axis 1") compare_filtered("sum_axis --size 1024x1024 --axis 0 --cpu") compare_filtered("sum_axis --size 1024x1024 --axis 0") compare_filtered("sum_axis --size 16x128x1024 --axis 1 --cpu") compare_filtered("sum_axis --size 
16x128x1024 --axis 1") compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu") compare_filtered("sum_axis --size 16x128x1024 --axis 0") compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu") compare_filtered("sum_axis --size 16x128x1024 --axis 0,1") compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu") compare_filtered("sum_axis --size 16x128x1024 --axis 0,2") compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu") compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1") compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu") compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1") compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu") compare_filtered("argmax --size 10x1024x128 --axis 1") compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu") compare_filtered("argmax --size 10x1024x128 --axis 2") compare_filtered("argmax --size 1024x1024 --axis 1 --cpu") compare_filtered("argmax --size 1024x1024 --axis 1") # Matmul ops compare_filtered("matmul_square --size 1024x1024") compare_filtered("matmul_square --size 1024x1024 --cpu") compare_filtered("matmul_square --size 16x1024x1024") compare_filtered("matmul_square --size 16x1024x1024 --cpu") compare_filtered( "matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1" ) compare_filtered( "matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1 --cpu" ) compare_filtered( "matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1" ) compare_filtered( "matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1 --cpu" ) compare_filtered("matmul --size 512x8192 --size 8192x512") compare_filtered("matmul --size 512x8192 --size 8192x512 --cpu") # compare_filtered("matmul --size 512x131072 --size 131072x512") # compare_filtered("matmul --size 512x131072 --size 131072x512 --cpu") compare_filtered("matmul --size 8192x512 --size 
512x8192") compare_filtered("matmul --size 8192x512 --size 512x8192 --cpu") # compare_filtered("matmul --size 131072x512 --size 512x512") # compare_filtered("matmul --size 131072x512 --size 512x512 --cpu") compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024") compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --cpu") compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --fused") compare_filtered( "linear --size 1024x1024 --size 1024 --size 128x1024 --fused --cpu" ) # Matvec ops compare_filtered("matmul --size 1x1x4096 --size 4096x4096 --cpu") compare_filtered("matmul --size 1x1x4096 --size 4096x4096") compare_filtered( "matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0 --cpu" ) compare_filtered( "matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0" ) compare_filtered("matmul --size 32x1x1000 --size 32x1000x128 --cpu") compare_filtered("matmul --size 32x1x1000 --size 32x1000x128") compare_filtered( "matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1 --cpu" ) compare_filtered( "matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1" ) # Various ops compare_filtered("softmax --size 32x16x1024 --axis 2") compare_filtered("softmax --size 32x16x1024 --axis 2 --cpu") compare_filtered("softmax --size 32x16x1024 --axis 2 --fused") compare_filtered("softmax --size 32x16x1024 --axis 2 --fused --cpu") compare_filtered("softmax --size 2x1024x1024 --axis 1") compare_filtered("softmax --size 2x1024x1024 --axis 1 --cpu") compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused") compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu") compare_filtered("relu --size 32x16x1024") compare_filtered("relu --size 32x16x1024 --cpu") compare_filtered("leaky_relu --size 32x16x1024") compare_filtered("leaky_relu --size 32x16x1024 --cpu") compare_filtered("elu --size 32x16x1024") compare_filtered("elu --size 32x16x1024 --cpu") 
compare_filtered("relu6 --size 32x16x1024") compare_filtered("relu6 --size 32x16x1024 --cpu") compare_filtered("softplus --size 32x16x1024") compare_filtered("softplus --size 32x16x1024 --cpu") compare_filtered("celu --size 32x16x1024") compare_filtered("celu --size 32x16x1024 --cpu") compare_filtered("log_sigmoid --size 32x16x1024") compare_filtered("log_sigmoid --size 32x16x1024 --cpu") compare_filtered("step --size 32x16x1024") compare_filtered("step --size 32x16x1024 --cpu") compare_filtered("selu --size 32x16x1024") compare_filtered("selu --size 32x16x1024 --cpu") # compare_filtered("mish --size 32x16x1024") NOTE: Torch does not implement Mish in MPS atm compare_filtered("mish --size 32x16x1024 --cpu") compare_filtered("prelu --size 32x16x1024") compare_filtered("prelu --size 32x16x1024 --cpu") compare_filtered("scalar_mul --size 32x16x1024") compare_filtered("scalar_mul --size 32x16x1024 --cpu") compare_filtered("cross_entropy --size 256x1024") compare_filtered("cross_entropy --size 256x1024 --cpu") compare_filtered("logsumexp --size 1024x1024 --axis 1") compare_filtered("logsumexp --size 1024x1024 --axis 1 --cpu") compare_filtered("logsumexp --size 1024x1024 --axis 0") compare_filtered("logsumexp --size 1024x1024 --axis 0 --cpu") compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2") compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2 --cpu") compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1") compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1 --cpu") compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0") compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0 --cpu") compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1") compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1 --cpu") compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1") 
compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1 --cpu") compare_filtered("concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2") compare_filtered( "concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2 --cpu" ) compare_filtered("conv1d --size 1x1000x80 --size 128x11x80") compare_filtered("conv1d --size 1x1000x80 --size 128x11x80 --cpu") compare_filtered("conv1d --size 16x1000x80 --size 128x11x80") compare_filtered("conv1d --size 4x1000x80 --size 128x11x80 --cpu") compare_filtered("conv2d --size 1x256x256x3 --size 8x3x3x3") compare_filtered("conv2d --size 1x256x256x3 --size 8x3x3x3 --cpu") compare_filtered("conv2d --size 16x256x256x3 --size 8x3x3x3") compare_filtered("conv2d --size 4x256x256x3 --size 8x3x3x3 --cpu") compare_filtered("cumsum --size 1024x1024 --axis 1 --cpu") compare_filtered("cumsum --size 1024x1024 --axis 0 --cpu") compare_filtered("cumsum --size 1024x1024 --axis 1") compare_filtered("cumsum --size 1024x1024 --axis 0") compare_filtered("cumsum --size 128x1024 --axis 1") compare_filtered("cumsum --size 128x1024 --axis 0") compare_filtered("cumsum --size 1024x4096 --axis 1") compare_filtered("cumsum --size 1024x4096 --axis 0") compare_filtered("cumsum --size 128x4096 --axis 1") compare_filtered("cumsum --size 128x4096 --axis 0") compare_filtered("cumsum --size 1024x7777 --axis 1") compare_filtered("cumsum --size 1024x7777 --axis 0") compare_filtered("cumsum --size 128x7777 --axis 1") compare_filtered("cumsum --size 128x7777 --axis 0") compare_filtered("cumsum --size 32768x128 --axis 1") compare_filtered("cumsum --size 32768x128 --axis 0") compare_filtered("sort --size 1024x1024 --axis 0") compare_filtered("sort --size 1024x1024 --axis 1") compare_filtered("sort --size 32768x128 --axis 0") compare_filtered("sort --size 32768x128 --axis 1") compare_filtered("sort --size 128x128 --axis 0 --cpu") compare_filtered("sort --size 128x128 --axis 1 --cpu") compare_filtered("topk --size 1024x1024 --axis 0") 
compare_filtered("topk --size 1024x1024 --axis 1") compare_filtered("topk --size 32768x128 --axis 0") compare_filtered("topk --size 32768x128 --axis 1") compare_filtered("topk --size 128x128 --axis 0 --cpu") compare_filtered("topk --size 128x128 --axis 1 --cpu") ================================================ FILE: benchmarks/python/compile_bench.py ================================================ # Copyright © 2023-2024 Apple Inc. import argparse import math import random import mlx.core as mx from time_utils import time_fn def bench_gelu(): def gelu(x): return x * (1 + mx.erf(x / math.sqrt(2))) / 2 x = mx.random.uniform(shape=(1000, 1024)) def gen_fun(fun): def bench_fun(x): for _ in range(10): x = fun(x) return x return bench_fun time_fn(gen_fun(gelu), x, msg="fixed gelu") time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu") def randint(): return random.randint(1, x.shape[0]) def gen_fun(fun): def bench_fun(x, y): x = x[: randint()] for _ in range(10): x = fun(x) y = fun(y) return x, y return bench_fun y = mx.random.uniform(shape=(1000, 1024)) time_fn(gen_fun(gelu), x, y, msg="variable gelu") time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu") time_fn( gen_fun(mx.compile(gelu, shapeless=True)), x, y, msg="shapeless variable gelu", ) def bench_layernorm(): weight = mx.random.uniform(shape=(4096,)).astype(mx.float16) bias = mx.random.uniform(shape=(4096,)).astype(mx.float16) mx.eval(weight, bias) def layernorm(x): x = x.astype(mx.float32) means = mx.mean(x, axis=-1, keepdims=True) var = mx.var(x, axis=-1, keepdims=True) x = (x - means) * mx.rsqrt(var + 1e-4) x = x.astype(mx.float16) return weight * x + bias x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16) def gen_fun(fun): def bench_fun(x): for _ in range(10): x = fun(x) return x return bench_fun time_fn(gen_fun(layernorm), x, msg="fixed layernorm") time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm") def randint(): return random.randint(1, 
x.shape[0]) def gen_fun(fun): def bench_fun(x): x = x[: randint()] for _ in range(10): x = fun(x) return x return bench_fun random.seed(0) time_fn(gen_fun(layernorm), x, msg="variable layernorm") random.seed(0) time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm") random.seed(0) time_fn( gen_fun(mx.compile(layernorm, shapeless=True)), x, msg="shapeless variable layernorm", ) if __name__ == "__main__": parser = argparse.ArgumentParser("Compile benchmarks.") args = parser.parse_args() bench_gelu() bench_layernorm() ================================================ FILE: benchmarks/python/conv1d_bench.py ================================================ import argparse import math import os import subprocess import time import mlx.core as mx import numpy as np import torch device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) device_name = device_name.decode("utf-8").strip("\n") N_warmup = 10 N_iter_bench = 100 N_iter_func = 5 def bench(f, a, b): for i in range(N_warmup): f(a, b) torch.mps.synchronize() s = time.perf_counter_ns() for i in range(N_iter_bench): f(a, b) e = time.perf_counter_ns() return (e - s) * 1e-9 def make_mx_conv_1D(strides=1, padding=0, groups=1): def mx_conv_1D(a, b): ys = [] for _ in range(N_iter_func): y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) mx.eval(ys) return ys return mx_conv_1D def make_pt_conv_1D(strides=1, padding=0, groups=1): @torch.no_grad() def pt_conv_1D(a, b): ys = [] for _ in range(N_iter_func): y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) torch.mps.synchronize() return ys return pt_conv_1D def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups): scale = 1.0 / math.sqrt(wH * C) a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype) b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype) a_mx = mx.array(a_np) b_mx = mx.array(b_np) a_pt = 
torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps") b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps") torch.mps.synchronize() f_mx = make_mx_conv_1D(strides, padding, groups) f_pt = make_pt_conv_1D(strides, padding, groups) time_torch = bench(f_pt, a_pt, b_pt) time_mlx = bench(f_mx, a_mx, b_mx) out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) out_pt = torch.conv1d( a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups ) out_pt = torch.permute(out_pt, (0, 2, 1)) out_pt = out_pt.numpy(force=True) atol = 2e-5 if np_dtype == np.float32 else 1e-4 if not np.allclose(out_pt, out_mx, atol=atol): print( f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" ) return time_mlx, time_torch if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run conv benchmarks") dtypes = ("float32",) shapes = ( (4, 32, 32, 5, 32, 1, 2, 1), (4, 32, 32, 5, 32, 1, 2, 2), (4, 32, 32, 5, 32, 1, 2, 4), (4, 32, 32, 5, 32, 1, 2, 8), (4, 32, 32, 5, 32, 1, 2, 8), (4, 32, 32, 5, 32, 1, 2, 16), (4, 32, 32, 5, 32, 1, 2, 32), (4, 32, 256, 5, 512, 1, 2, 2), (4, 32, 256, 5, 512, 1, 2, 128), (4, 32, 256, 5, 512, 1, 2, 256), ) for dtype in dtypes: print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%") for N, iH, C, wH, O, strides, padding, groups in shapes: np_dtype = getattr(np, dtype) time_mlx, time_torch = bench_shape( N, iH, C, wH, O, strides, padding, np_dtype, groups ) diff = time_torch / time_mlx - 1.0 print( f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. 
* diff:+5.2f}%" ) if time_mlx >= 2.0 * time_torch: print("ATTENTION ^^^^^^^") ================================================ FILE: benchmarks/python/conv2d_bench_cpu.py ================================================ import argparse import math import time import mlx.core as mx import numpy as np import torch N_warmup = 1 N_iter_bench = 10 N_iter_func = 5 mx.set_default_device(mx.cpu) def bench(f, a, b): for i in range(N_warmup): f(a, b) s = time.perf_counter_ns() for i in range(N_iter_bench): f(a, b) e = time.perf_counter_ns() return (e - s) * 1e-9 def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): def mx_conv_2D(a, b): ys = [] for i in range(N_iter_func): y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) mx.eval(ys) return ys return mx_conv_2D def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): @torch.no_grad() def pt_conv_2D(a, b): ys = [] for i in range(N_iter_func): y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) return ys return pt_conv_2D def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): scale = 1.0 / math.sqrt(kH * kH * C) a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype( np_dtype ) a_mx = mx.array(a_np) b_mx = mx.array(b_np) a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu") b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu") f_mx = make_mx_conv_2D(strides, padding, groups) f_pt = make_pt_conv_2D(strides, padding, groups) time_torch = bench(f_pt, a_pt, b_pt) time_mlx = bench(f_mx, a_mx, b_mx) out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) out_pt = torch.conv2d( a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups ) out_pt = torch.permute(out_pt, (0, 2, 3, 1)) out_pt = out_pt.numpy(force=True) atol = 2e-5 if np_dtype == np.float32 else 1e-4 if not np.allclose(out_pt, 
out_mx, atol=atol): print( f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" ) return time_mlx, time_torch if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run conv benchmarks") dtypes = ("float32",) shapes = ( (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1), (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1), (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1), # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2), # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16), # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64), (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1), (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1), (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1), (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1), ) for dtype in dtypes: print( "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%" ) for N, H, W, C, kH, kW, O, strides, padding, groups in shapes: np_dtype = getattr(np, dtype) time_mlx, time_torch = bench_shape( N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype ) diff = time_torch / time_mlx - 1.0 print( f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. 
* diff:+5.2f}%" ) if time_mlx >= 2.0 * time_torch: print("ATTENTION ^^^^^^^") ================================================ FILE: benchmarks/python/conv2d_train_bench_cpu.py ================================================ import time import mlx.core as mx import mlx.nn import mlx.optimizers as opt import torch def bench_mlx(steps: int = 20) -> float: mx.set_default_device(mx.cpu) class BenchNetMLX(mlx.nn.Module): # simple encoder-decoder net def __init__(self, in_channels, hidden_channels=32): super().__init__() self.net = mlx.nn.Sequential( mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1), mlx.nn.ReLU(), mlx.nn.Conv2d( hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1 ), mlx.nn.ReLU(), mlx.nn.ConvTranspose2d( 2 * hidden_channels, hidden_channels, kernel_size=3, padding=1 ), mlx.nn.ReLU(), mlx.nn.ConvTranspose2d( hidden_channels, in_channels, kernel_size=3, padding=1 ), ) def __call__(self, input): return self.net(input) benchNet = BenchNetMLX(3) mx.eval(benchNet.parameters()) optim = opt.Adam(learning_rate=1e-3) inputs = mx.random.normal([10, 256, 256, 3]) params = benchNet.parameters() optim.init(params) state = [benchNet.state, optim.state] def loss_fn(params, image): benchNet.update(params) pred_image = benchNet(image) return (pred_image - image).abs().mean() def step(params, image): loss, grads = mx.value_and_grad(loss_fn)(params, image) optim.update(benchNet, grads) return loss total_time = 0.0 print("MLX:") for i in range(steps): start_time = time.perf_counter() step(benchNet.parameters(), inputs) mx.eval(state) end_time = time.perf_counter() print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms") total_time += (end_time - start_time) * 1000 return total_time def bench_torch(steps: int = 20) -> float: device = torch.device("cpu") class BenchNetTorch(torch.nn.Module): # simple encoder-decoder net def __init__(self, in_channels, hidden_channels=32): super().__init__() self.net = torch.nn.Sequential( 
torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1), torch.nn.ReLU(), torch.nn.Conv2d( hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1 ), torch.nn.ReLU(), torch.nn.ConvTranspose2d( 2 * hidden_channels, hidden_channels, kernel_size=3, padding=1 ), torch.nn.ReLU(), torch.nn.ConvTranspose2d( hidden_channels, in_channels, kernel_size=3, padding=1 ), ) def forward(self, input): return self.net(input) benchNet = BenchNetTorch(3).to(device) optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3) inputs = torch.randn(10, 3, 256, 256, device=device) def loss_fn(pred_image, image): return (pred_image - image).abs().mean() total_time = 0.0 print("PyTorch:") for i in range(steps): start_time = time.perf_counter() optim.zero_grad() pred_image = benchNet(inputs) loss = loss_fn(pred_image, inputs) loss.backward() optim.step() end_time = time.perf_counter() print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms") total_time += (end_time - start_time) * 1000 return total_time def main(): steps = 20 time_mlx = bench_mlx(steps) time_torch = bench_torch(steps) print(f"average time of MLX: {time_mlx/steps:9.2f} ms") print(f"total time of MLX: {time_mlx:9.2f} ms") print(f"average time of PyTorch: {time_torch/steps:9.2f} ms") print(f"total time of PyTorch: {time_torch:9.2f} ms") diff = time_torch / time_mlx - 1.0 print(f"torch/mlx diff: {100. 
* diff:+5.2f}%") if __name__ == "__main__": main() ================================================ FILE: benchmarks/python/conv2d_transpose_bench_cpu.py ================================================ import argparse import math import time import mlx.core as mx import numpy as np import torch N_warmup = 1 N_iter_bench = 10 N_iter_func = 5 def bench(f, a, b): for i in range(N_warmup): f(a, b) s = time.perf_counter_ns() for i in range(N_iter_bench): f(a, b) e = time.perf_counter_ns() return (e - s) * 1e-9 def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): def mx_conv_transpose_2D(a, b): ys = [] for i in range(N_iter_func): y = mx.conv_transpose2d( a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu ) ys.append(y) mx.eval(ys) return ys return mx_conv_transpose_2D def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): @torch.no_grad() def pt_conv_transpose_2D(a, b): ys = [] for i in range(N_iter_func): y = torch.conv_transpose2d( a, b, stride=strides, padding=padding, groups=groups ) ys.append(y) return ys return pt_conv_transpose_2D def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): scale = 1.0 / math.sqrt(kH * kH * C) a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype( np_dtype ) a_mx = mx.array(a_np) b_mx = mx.array(b_np) a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu") b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu") f_mx = make_mx_conv_transpose_2D(strides, padding, groups) f_pt = make_pt_conv_transpose_2D(strides, padding, groups) time_torch = bench(f_pt, a_pt, b_pt) time_mlx = bench(f_mx, a_mx, b_mx) out_mx = mx.conv_transpose2d( a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu ) out_pt = torch.conv_transpose2d( a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups ) out_pt = torch.permute(out_pt, (0, 2, 3, 1)) 
out_pt = out_pt.numpy(force=True) atol = 2e-5 if np_dtype == np.float32 else 1e-4 if not np.allclose(out_pt, out_mx, atol=atol): print( f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" ) return time_mlx, time_torch if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run conv benchmarks") dtypes = ("float32",) shapes = ( (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1), (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1), (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1), (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1), (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1), (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1), (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1), ) for dtype in dtypes: print( "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%" ) for N, H, W, C, kH, kW, O, strides, padding, groups in shapes: np_dtype = getattr(np, dtype) time_mlx, time_torch = bench_shape( N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype ) diff = time_torch / time_mlx - 1.0 print( f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. 
* diff:+5.2f}%" ) if time_mlx >= 2.0 * time_torch: print("ATTENTION ^^^^^^^") ================================================ FILE: benchmarks/python/conv3d_bench.py ================================================ import math import time import mlx.core as mx import numpy as np import torch N_warmup = 2 N_iter_bench = 10 N_iter_func = 10 def bench(f, a, b, b_prime): for i in range(N_warmup): f(a, b, b_prime) torch.mps.synchronize() s = time.perf_counter_ns() for i in range(N_iter_bench): f(a, b, b_prime) e = time.perf_counter_ns() return (e - s) * 1e-9 def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): def mx_conv_3D(a, b, b_prime): y = a for i in range(N_iter_func): y = mx.conv3d(y, b, stride=strides, padding=padding, groups=groups) y = mx.conv3d(y, b_prime, stride=strides, padding=padding, groups=groups) mx.eval(y) return y return mx_conv_3D def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): @torch.no_grad() def pt_conv_3D(a, b, b_prime): y = a for i in range(N_iter_func): y = torch.conv3d(y, b, stride=strides, padding=padding, groups=groups) y = torch.conv3d(y, b_prime, stride=strides, padding=padding, groups=groups) torch.mps.synchronize() return y return pt_conv_3D def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype): scale = 1.0 / math.sqrt(kD * kH * kW * C) a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)) b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))) b_prime_np = np.random.uniform(-scale, scale, (C, kD, kH, kW, int(O / groups))) a_np, b_np, b_prime_np = map(lambda x: x.astype(np_dtype), (a_np, b_np, b_prime_np)) a_mx, b_mx, b_prime_mx = map(lambda x: mx.array(x), (a_np, b_np, b_prime_np)) a_pt, b_pt, b_prime_pt = map( lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("mps"), (a_np, b_np, b_prime_np), ) torch.mps.synchronize() f_mx = make_mx_conv_3D(strides, padding, groups) f_pt = make_pt_conv_3D(strides, padding, groups) time_torch = bench(f_pt, a_pt, 
b_pt, b_prime_pt) time_mlx = bench(f_mx, a_mx, b_mx, b_prime_mx) # Measure MLX memory mx.clear_cache() mx.reset_peak_memory() y = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) mx.eval(y) mlx_peak_mb = mx.get_peak_memory() / 1024**2 mlx_active_mb = mx.get_active_memory() / 1024**2 del y # Measure PyTorch MPS memory torch.mps.synchronize() torch.mps.empty_cache() y = torch.conv3d(a_pt, b_pt, stride=strides, padding=padding, groups=groups) torch.mps.synchronize() pt_current_mb = torch.mps.current_allocated_memory() / 1024**2 pt_driver_mb = torch.mps.driver_allocated_memory() / 1024**2 del y out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) out_pt = torch.conv3d( a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups ) out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)) out_pt = out_pt.numpy(force=True) atol = 2e-5 if np_dtype == np.float32 else 5e-4 if not np.allclose(out_pt, out_mx, atol=atol): print( f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} " f"[strides = {strides}, padding = {padding}, groups = {groups}] " f"with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" ) return time_mlx, time_torch, mlx_peak_mb, mlx_active_mb, pt_current_mb, pt_driver_mb if __name__ == "__main__": dtypes = ("float16", "float32") shapes = ( # (C % 16 == 0) (4, 16, 16, 16, 32, 3, 3, 3, 32, (1, 1, 1), (1, 1, 1), 1), (4, 16, 16, 16, 64, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1), (4, 16, 16, 16, 128, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1), (4, 32, 32, 32, 64, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1), (4, 32, 32, 32, 128, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1), # Larger spatial dims (2, 64, 64, 64, 32, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1), (1, 64, 64, 64, 64, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1), # Strided (4, 32, 32, 32, 64, 3, 3, 3, 128, (2, 2, 2), (1, 1, 1), 1), # Asymmetric kernels (4, 32, 32, 32, 64, 3, 1, 1, 128, (1, 1, 1), (1, 0, 0), 1), (4, 32, 32, 32, 64, 1, 3, 3, 128, (1, 1, 1), (0, 1, 1), 1), # (C % 16 != 
0) (4, 16, 16, 16, 21, 3, 3, 3, 21, (1, 1, 1), (1, 1, 1), 1), (4, 16, 16, 16, 55, 3, 3, 3, 55, (1, 1, 1), (1, 1, 1), 1), (4, 32, 32, 32, 55, 3, 3, 3, 55, (1, 1, 1), (1, 1, 1), 1), (4, 16, 16, 16, 3, 3, 3, 3, 32, (1, 1, 1), (1, 1, 1), 1), ) for dtype in dtypes: print(f"\n{'=' * 120}" f"\n dtype: {dtype}" f"\n{'=' * 120}") print( f"{'(N, D, H, W, C)':<26s} {'( O, kD, kH, kW, C)':<24s} " f"{'stride':<12s} {'pads':<12s} {'groups':>6s} " f"{'diff%':>7s} " f"{'MLX peak':>9s} {'MLX act':>8s} {'PT cur':>8s} {'PT drv':>8s}" ) for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes: np_dtype = getattr(np, dtype) time_mlx, time_torch, mlx_peak, mlx_act, pt_cur, pt_drv = bench_shape( N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype ) diff = time_torch / time_mlx - 1.0 print( f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), " f"{strides}, {padding}, {groups:6d}, " f"{100. * diff:+6.1f}% " f"{mlx_peak:8.1f} {mlx_act:7.1f} {pt_cur:7.1f} {pt_drv:7.1f}" ) ================================================ FILE: benchmarks/python/conv3d_bench_cpu.py ================================================ import argparse import math import time import mlx.core as mx import numpy as np import torch N_warmup = 1 N_iter_bench = 10 N_iter_func = 5 mx.set_default_device(mx.cpu) def bench(f, a, b): for i in range(N_warmup): f(a, b) s = time.perf_counter_ns() for i in range(N_iter_bench): f(a, b) e = time.perf_counter_ns() return (e - s) * 1e-9 def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1): def mx_conv_3D(a, b): ys = [] for i in range(N_iter_func): y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) mx.eval(ys) return ys return mx_conv_3D def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1): @torch.no_grad() def pt_conv_3D(a, b): ys = [] for i in range(N_iter_func): y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) return ys return pt_conv_3D def 
bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype): scale = 1.0 / math.sqrt(kD * kH * kW * C) a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype) b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype( np_dtype ) a_mx = mx.array(a_np) b_mx = mx.array(b_np) a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu") b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu") f_mx = make_mx_conv_3D(strides, padding, groups) f_pt = make_pt_conv_3D(strides, padding, groups) time_torch = bench(f_pt, a_pt, b_pt) time_mlx = bench(f_mx, a_mx, b_mx) out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) out_pt = torch.conv3d( a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups ) out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)) out_pt = out_pt.numpy(force=True) atol = 2e-5 if np_dtype == np.float32 else 1e-4 if not np.allclose(out_pt, out_mx, atol=atol): print( f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" ) return time_mlx, time_torch if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run conv benchmarks") dtypes = ("float32",) shapes = ( (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1), (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1), ) for dtype in dtypes: print( "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%" ) for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes: np_dtype = getattr(np, dtype) time_mlx, time_torch = bench_shape( N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype ) diff = time_torch / time_mlx - 1.0 print( f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. 
* diff:+5.2f}%" ) if time_mlx >= 2.0 * time_torch: print("ATTENTION ^^^^^^^") ================================================ FILE: benchmarks/python/conv3d_train_bench_cpu.py ================================================ import time import mlx.core as mx import mlx.nn import mlx.optimizers as opt import torch def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float: mx.set_default_device(mx.cpu) class BenchNetMLX(mlx.nn.Module): # simple encoder-decoder net def __init__(self, in_channels, hidden_channels=16): super().__init__() self.net = mlx.nn.Sequential( mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1), mlx.nn.ReLU(), mlx.nn.Conv3d( hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1 ), mlx.nn.ReLU(), mlx.nn.ConvTranspose3d( 2 * hidden_channels, hidden_channels, kernel_size=3, padding=1 ), mlx.nn.ReLU(), mlx.nn.ConvTranspose3d( hidden_channels, in_channels, kernel_size=3, padding=1 ), ) def __call__(self, input): return self.net(input) benchNet = BenchNetMLX(3) mx.eval(benchNet.parameters()) optim = opt.Adam(learning_rate=1e-3) inputs = mx.random.normal(shape) params = benchNet.parameters() optim.init(params) state = [benchNet.state, optim.state] def loss_fn(params, image): benchNet.update(params) pred_image = benchNet(image) return (pred_image - image).abs().mean() def step(params, image): loss, grads = mx.value_and_grad(loss_fn)(params, image) optim.update(benchNet, grads) return loss total_time = 0.0 print("MLX:") for i in range(steps): start_time = time.perf_counter() step(benchNet.parameters(), inputs) mx.eval(state) end_time = time.perf_counter() print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms") total_time += (end_time - start_time) * 1000 return total_time def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float: device = torch.device("cpu") class BenchNetTorch(torch.nn.Module): # simple encoder-decoder net def __init__(self, in_channels, hidden_channels=16): super().__init__() 
self.net = torch.nn.Sequential( torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1), torch.nn.ReLU(), torch.nn.Conv3d( hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1 ), torch.nn.ReLU(), torch.nn.ConvTranspose3d( 2 * hidden_channels, hidden_channels, kernel_size=3, padding=1 ), torch.nn.ReLU(), torch.nn.ConvTranspose3d( hidden_channels, in_channels, kernel_size=3, padding=1 ), ) def forward(self, input): return self.net(input) benchNet = BenchNetTorch(3).to(device) optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3) inputs = torch.randn(*shape, device=device) def loss_fn(pred_image, image): return (pred_image - image).abs().mean() total_time = 0.0 print("PyTorch:") for i in range(steps): start_time = time.perf_counter() optim.zero_grad() pred_image = benchNet(inputs) loss = loss_fn(pred_image, inputs) loss.backward() optim.step() end_time = time.perf_counter() print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms") total_time += (end_time - start_time) * 1000 return total_time def main(): steps = 10 time_mlx = bench_mlx(steps) time_torch = bench_torch(steps) print(f"average time of MLX: {time_mlx/steps:9.2f} ms") print(f"total time of MLX: {time_mlx:9.2f} ms") print(f"average time of PyTorch: {time_torch/steps:9.2f} ms") print(f"total time of PyTorch: {time_torch:9.2f} ms") diff = time_torch / time_mlx - 1.0 print(f"torch/mlx diff: {100. 
* diff:+5.2f}%") if __name__ == "__main__": main() ================================================ FILE: benchmarks/python/conv3d_transpose_bench_cpu.py ================================================ import argparse import math import time import mlx.core as mx import numpy as np import torch N_warmup = 1 N_iter_bench = 10 N_iter_func = 5 mx.set_default_device(mx.cpu) def bench(f, a, b): for i in range(N_warmup): f(a, b) s = time.perf_counter_ns() for i in range(N_iter_bench): f(a, b) e = time.perf_counter_ns() return (e - s) * 1e-9 def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): def mx_conv_3D(a, b): ys = [] for i in range(N_iter_func): y = mx.conv_transpose3d( a, b, stride=strides, padding=padding, groups=groups ) ys.append(y) mx.eval(ys) return ys return mx_conv_3D def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): @torch.no_grad() def pt_conv_3D(a, b): ys = [] for i in range(N_iter_func): y = torch.conv_transpose3d( a, b, stride=strides, padding=padding, groups=groups ) ys.append(y) return ys return pt_conv_3D def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype): scale = 1.0 / math.sqrt(kD * kH * kW * C) a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype) b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype( np_dtype ) a_mx = mx.array(a_np) b_mx = mx.array(b_np) a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu") b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu") f_mx = make_mx_conv_3D(strides, padding, groups) f_pt = make_pt_conv_3D(strides, padding, groups) time_torch = bench(f_pt, a_pt, b_pt) time_mlx = bench(f_mx, a_mx, b_mx) out_mx = mx.conv_transpose3d( a_mx, b_mx, stride=strides, padding=padding, groups=groups ) out_pt = torch.conv_transpose3d( a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups ) out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)) out_pt = out_pt.numpy(force=True) atol = 
2e-5 if np_dtype == np.float32 else 1e-4 if not np.allclose(out_pt, out_mx, atol=atol): print( f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" ) return time_mlx, time_torch if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run conv benchmarks") dtypes = ("float32",) shapes = ( (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1), (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1), ) for dtype in dtypes: print( "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%" ) for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes: np_dtype = getattr(np, dtype) time_mlx, time_torch = bench_shape( N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype ) diff = time_torch / time_mlx - 1.0 print( f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. 
* diff:+5.2f}%" ) if time_mlx >= 2.0 * time_torch: print("ATTENTION ^^^^^^^") ================================================ FILE: benchmarks/python/conv_bench.py ================================================ import argparse import math import os import subprocess import time import mlx.core as mx import numpy as np import torch device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) device_name = device_name.decode("utf-8").strip("\n") N_warmup = 10 N_iter_bench = 100 N_iter_func = 5 def bench(f, a, b): for i in range(N_warmup): f(a, b) torch.mps.synchronize() s = time.perf_counter_ns() for i in range(N_iter_bench): f(a, b) e = time.perf_counter_ns() return (e - s) * 1e-9 def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): def mx_conv_2D(a, b): ys = [] for i in range(N_iter_func): y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) mx.eval(ys) return ys return mx_conv_2D def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): @torch.no_grad() def pt_conv_2D(a, b): ys = [] for i in range(N_iter_func): y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) torch.mps.synchronize() return ys return pt_conv_2D def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): scale = 1.0 / math.sqrt(kH * kH * C) a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype( np_dtype ) a_mx = mx.array(a_np) b_mx = mx.array(b_np) a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps") b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps") torch.mps.synchronize() f_mx = make_mx_conv_2D(strides, padding, groups) f_pt = make_pt_conv_2D(strides, padding, groups) time_torch = bench(f_pt, a_pt, b_pt) time_mlx = bench(f_mx, a_mx, b_mx) out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) out_pt = torch.conv2d( a_pt.to("cpu"), b_pt.to("cpu"), 
stride=strides, padding=padding, groups=groups ) out_pt = torch.permute(out_pt, (0, 2, 3, 1)) out_pt = out_pt.numpy(force=True) atol = 2e-5 if np_dtype == np.float32 else 1e-4 if not np.allclose(out_pt, out_mx, atol=atol): print( f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" ) return time_mlx, time_torch if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run conv benchmarks") dtypes = ("float32",) shapes = ( (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1), (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1), (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1), (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2), (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16), (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64), (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1), (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1), (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1), (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1), ) for dtype in dtypes: print( "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%" ) for N, H, W, C, kH, kW, O, strides, padding, groups in shapes: np_dtype = getattr(np, dtype) time_mlx, time_torch = bench_shape( N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype ) diff = time_torch / time_mlx - 1.0 print( f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. 
* diff:+5.2f}%" ) if time_mlx >= 2.0 * time_torch: print("ATTENTION ^^^^^^^") ================================================ FILE: benchmarks/python/conv_transpose_bench.py ================================================ import argparse import math import os import subprocess import time import mlx.core as mx import numpy as np import torch N_warmup = 10 N_iter_bench = 100 N_iter_func = 5 def bench(f, a, b): for i in range(N_warmup): f(a, b) torch.mps.synchronize() s = time.perf_counter_ns() for i in range(N_iter_bench): f(a, b) e = time.perf_counter_ns() return (e - s) * 1e-9 def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): def mx_conv_transpose_2D(a, b): ys = [] for i in range(N_iter_func): y = mx.conv_transpose2d( a, b, stride=strides, padding=padding, groups=groups ) ys.append(y) mx.eval(ys) return ys return mx_conv_transpose_2D def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): @torch.no_grad() def pt_conv_transpose_2D(a, b): ys = [] for i in range(N_iter_func): y = torch.conv_transpose2d( a, b, stride=strides, padding=padding, groups=groups ) ys.append(y) torch.mps.synchronize() return ys return pt_conv_transpose_2D def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): scale = 1.0 / math.sqrt(kH * kH * C) a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype( np_dtype ) a_mx = mx.array(a_np) b_mx = mx.array(b_np) a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps") b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps") torch.mps.synchronize() f_mx = make_mx_conv_transpose_2D(strides, padding, groups) f_pt = make_pt_conv_transpose_2D(strides, padding, groups) time_torch = bench(f_pt, a_pt, b_pt) time_mlx = bench(f_mx, a_mx, b_mx) out_mx = mx.conv_transpose2d( a_mx, b_mx, stride=strides, padding=padding, groups=groups ) out_pt = torch.conv_transpose2d( a_pt.to("cpu"), b_pt.to("cpu"), 
stride=strides, padding=padding, groups=groups ) out_pt = torch.permute(out_pt, (0, 2, 3, 1)) out_pt = out_pt.numpy(force=True) atol = 2e-5 if np_dtype == np.float32 else 1e-4 if not np.allclose(out_pt, out_mx, atol=atol): print( f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" ) return time_mlx, time_torch if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run conv benchmarks") dtypes = ("float32",) shapes = ( (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1), (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1), (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1), (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1), (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1), (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1), (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1), (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1), (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1), (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1), ) for dtype in dtypes: print( "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%" ) for N, H, W, C, kH, kW, O, strides, padding, groups in shapes: np_dtype = getattr(np, dtype) time_mlx, time_torch = bench_shape( N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype ) diff = time_torch / time_mlx - 1.0 print( f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. 
* diff:+5.2f}%" ) if time_mlx >= 2.0 * time_torch: print("ATTENTION ^^^^^^^") ================================================ FILE: benchmarks/python/conv_unaligned_bench.py ================================================ import math import time import mlx.core as mx import numpy as np import torch N_warmup = 10 N_iter_bench = 100 N_iter_func = 5 def bench(f, a, b): for i in range(N_warmup): f(a, b) torch.mps.synchronize() s = time.perf_counter_ns() for i in range(N_iter_bench): f(a, b) e = time.perf_counter_ns() return (e - s) * 1e-9 def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): def mx_conv_2D(a, b): ys = [] for i in range(N_iter_func): y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) mx.eval(ys) return ys return mx_conv_2D def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): @torch.no_grad() def pt_conv_2D(a, b): ys = [] for i in range(N_iter_func): y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) torch.mps.synchronize() return ys return pt_conv_2D def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): scale = 1.0 / math.sqrt(kH * kH * C) a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype( np_dtype ) a_mx = mx.array(a_np) b_mx = mx.array(b_np) a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps") b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps") torch.mps.synchronize() f_mx = make_mx_conv_2D(strides, padding, groups) f_pt = make_pt_conv_2D(strides, padding, groups) time_torch = bench(f_pt, a_pt, b_pt) time_mlx = bench(f_mx, a_mx, b_mx) out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) out_pt = torch.conv2d( a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups ) out_pt = torch.permute(out_pt, (0, 2, 3, 1)) out_pt = out_pt.numpy(force=True) atol = 2e-5 if np_dtype == np.float32 else 
1e-4 if not np.allclose(out_pt, out_mx, atol=atol): print( f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" ) return time_mlx, time_torch if __name__ == "__main__": dtype = "float32" shapes = ( (4, 32, 32, 21, 3, 3, 128), (4, 32, 32, 21, 3, 3, 37), (4, 32, 32, 370, 3, 3, 370), (4, 32, 32, 370, 7, 7, 128), (2, 320, 640, 21, 7, 7, 21), ) for N, H, W, C, kh, kw, O in shapes: time_mlx, time_torch = bench_shape( N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype ) diff = time_torch / time_mlx - 1.0 print( f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%" ) if time_mlx >= 2.0 * time_torch: print("ATTENTION ^^^^^^^") ================================================ FILE: benchmarks/python/distributed_bench.py ================================================ # Copyright © 2024 Apple Inc. """ Run with: mpirun -n 2 python /path/to/distributed_bench.py """ import time import mlx.core as mx def time_fn(fn, *args, **kwargs): msg = kwargs.pop("msg", None) world = mx.distributed.init() if world.rank() == 0: if msg: print(f"Timing {msg} ...", end=" ") else: print(f"Timing {fn.__name__} ...", end=" ") # warmup for _ in range(5): mx.eval(fn(*args, **kwargs)) num_iters = 100 tic = time.perf_counter() for _ in range(num_iters): x = mx.eval(fn(*args, **kwargs)) toc = time.perf_counter() msec = 1e3 * (toc - tic) / num_iters if world.rank() == 0: print(f"{msec:.5f} msec") def time_all_sum(): shape = (4096,) x = mx.random.uniform(shape=shape) mx.eval(x) def sine(x): for _ in range(20): x = mx.sin(x) return x time_fn(sine, x) def all_sum_plain(x): for _ in range(20): x = mx.distributed.all_sum(x) return x time_fn(all_sum_plain, x) def all_sum_with_sine(x): for _ in range(20): x = mx.sin(x) x = mx.distributed.all_sum(x) return x time_fn(all_sum_with_sine, x) if __name__ == "__main__": time_all_sum() 
================================================ FILE: benchmarks/python/einsum_bench.py ================================================ # Copyright © 2024 Apple Inc. import time import mlx.core as mx import numpy as np def timeit(fn, its=100, args=[]): for _ in range(5): fn(*args) tic = time.perf_counter() for _ in range(its): fn(*args) toc = time.perf_counter() return 1e3 * (toc - tic) / its def time_little_einsum_path(): subscripts = "ik,kj->ij" x = mx.ones((32, 32)) y = mx.ones((32, 32)) mx_time = timeit(mx.einsum_path, args=(subscripts, x, y)) x = np.array(x) y = np.array(y) np_time = timeit(np.einsum_path, args=(subscripts, x, y)) print("Timing little einsum path...") print(f"MLX ... {mx_time:.3f} ms") print(f"NumPy... {np_time:.3f} ms") def time_big_einsum_path(): chars = list("abcdefgh") char_to_dim = {c: v for v, c in enumerate(chars)} num_inputs = 10 inputs = [] subscripts = [] for _ in range(num_inputs): subscript = np.random.choice(chars, size=5, replace=False).tolist() subscripts.append("".join(subscript)) inputs.append(np.ones(list(char_to_dim[c] for c in subscript))) subscripts = ",".join(subscripts) np_time = timeit(np.einsum_path, args=(subscripts, *inputs)) inputs = [mx.array(x) for x in inputs] mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs)) print("Timing big einsum path...") print(f"MLX ... {mx_time:.3f} ms") print(f"NumPy... 
{np_time:.3f} ms") def time_attention(): def regular_attention(x): # shape [batch, sequence, num_heads, head_dim] queries, keys, values = x, x, x scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1) scores = mx.softmax(scores, axis=-1) output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2) mx.eval(output) def einsum_attention(x): # shape [batch, sequence, num_heads, head_dim] queries, keys, values = x, x, x scores = mx.einsum("itjk,iujk->ijtu", queries, keys) scores = mx.softmax(scores, axis=-1) output = mx.einsum("ijtu,iujk->itjk", scores, values) mx.eval(output) x = mx.random.uniform(shape=(8, 512, 32, 128)) regular_time = timeit(regular_attention, args=(x,)) ein_time = timeit(einsum_attention, args=(x,)) print("Timing einsum attention...") print(f"Regular ... {regular_time:.3f} ms") print(f"Einsum ... {ein_time:.3f} ms") if __name__ == "__main__": time_little_einsum_path() time_big_einsum_path() time_attention() ================================================ FILE: benchmarks/python/fft_bench.py ================================================ # Copyright © 2024 Apple Inc. 
import matplotlib import mlx.core as mx import numpy as np import sympy import torch from time_utils import measure_runtime matplotlib.use("Agg") import matplotlib.pyplot as plt def bandwidth_gb(runtime_ms, system_size): bytes_per_fft = np.dtype(np.complex64).itemsize * 2 bytes_per_gb = 1e9 ms_per_s = 1e3 return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb def run_bench(system_size, fft_sizes, backend="mlx", dim=1): def fft_mlx(x): if dim == 1: out = mx.fft.fft(x) elif dim == 2: out = mx.fft.fft2(x) mx.eval(out) return out def fft_mps(x): if dim == 1: out = torch.fft.fft(x) elif dim == 2: out = torch.fft.fft2(x) torch.mps.synchronize() return out bandwidths = [] for n in fft_sizes: batch_size = system_size // n**dim shape = [batch_size] + [n for _ in range(dim)] if backend == "mlx": x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64) x = mx.array(x_np) mx.eval(x) fft = fft_mlx elif backend == "mps": x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64) x = torch.tensor(x_np, device="mps") torch.mps.synchronize() fft = fft_mps else: raise NotImplementedError() runtime_ms = measure_runtime(fft, x=x) bandwidth = bandwidth_gb(runtime_ms, np.prod(shape)) print(n, bandwidth) bandwidths.append(bandwidth) return np.array(bandwidths) def time_fft(): x = np.array(range(2, 512)) system_size = int(2**26) print("MLX GPU") with mx.stream(mx.gpu): gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x) print("MPS GPU") mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps") print("CPU") system_size = int(2**20) with mx.stream(mx.cpu): cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x) x = np.array(x) all_indices = x - x[0] radix_2to13 = ( np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0] ) bluesteins = ( np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0] ) for indices, name in [ (all_indices, "All"), (radix_2to13, 
"Radix 2-13"), (bluesteins, "Bluestein's"), ]: # plot bandwidths print(name) plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU") plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS") plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU") plt.title(f"MLX FFT Benchmark -- {name}") plt.xlabel("N") plt.ylabel("Bandwidth (GB/s)") plt.legend() plt.savefig(f"{name}.png") plt.clf() av_gpu_bandwidth = np.mean(gpu_bandwidths) av_mps_bandwidth = np.mean(mps_bandwidths) av_cpu_bandwidth = np.mean(cpu_bandwidths) print("Average bandwidths:") print("GPU:", av_gpu_bandwidth) print("MPS:", av_mps_bandwidth) print("CPU:", av_cpu_bandwidth) portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x) print("Percent MLX faster than MPS: ", portion_faster * 100) if __name__ == "__main__": time_fft() ================================================ FILE: benchmarks/python/gather_bench.py ================================================ # Copyright © 2023-2024 Apple Inc. 
import argparse import mlx.core as mx import torch from time_utils import measure_runtime def benchmark_gather_mlx(x_shape, idx_shape): def gather(x, idx): mx.eval(x[idx]) idx = mx.random.randint(0, x_shape[0] - 1, idx_shape) x = mx.random.normal(x_shape).astype(mx.float32) runtime = measure_runtime(gather, x=x, idx=idx) print(f"MLX: {runtime:.3f}ms") def benchmark_gather_torch(x_shape, idx_shape, device): def gather(x, idx, device): _ = x[idx] if device == torch.device("mps"): torch.mps.synchronize() idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device) x = torch.randn(x_shape, dtype=torch.float32).to(device) runtime = measure_runtime(gather, x=x, idx=idx, device=device) print(f"PyTorch: {runtime:.3f}ms") if __name__ == "__main__": parser = argparse.ArgumentParser("Gather benchmarks.") parser.add_argument("--cpu", action="store_true", help="Use the CPU.") args = parser.parse_args() if args.cpu: mx.set_default_device(mx.cpu) device = torch.device("cpu") else: device = torch.device("mps") idx_shapes = [(1_000_000,), (100_000,), ()] x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)] for x_shape, idx_shape in zip(x_shapes, idx_shapes): print("=" * 20) print(f"X {x_shape}, Indices {idx_shape}") benchmark_gather_mlx(x_shape, idx_shape) benchmark_gather_torch(x_shape, idx_shape, device=device) ================================================ FILE: benchmarks/python/gather_mm_bench.py ================================================ # Copyright © 2025 Apple Inc. 
import mlx.core as mx from time_utils import time_fn N = 1024 D = 1024 M = 1024 E = 32 I = 4 def gather_sort(x, indices): N, M = indices.shape indices = indices.flatten() order = mx.argsort(indices) inv_order = mx.argsort(order) return x.flatten(0, -3)[order // M], indices[order], inv_order def scatter_unsort(x, inv_order, shape=None): x = x[inv_order] if shape is not None: x = mx.unflatten(x, 0, shape) return x def gather_mm_simulate(x, w, indices): x, idx, inv_order = gather_sort(x, indices) for i in range(2): y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0) x = y[:, None] x = scatter_unsort(x, inv_order, indices.shape) return x def time_gather_mm(): x = mx.random.normal((N, 1, 1, D)) / 1024**0.5 w1 = mx.random.normal((E, M, D)) / 1024**0.5 w2 = mx.random.normal((E, D, M)) / 1024**0.5 indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32) sorted_indices = mx.sort(indices.flatten()).reshape(N, I) mx.eval(x, w1, w2, indices, sorted_indices) def gather_mm(x, w1, w2, indices, sort): idx = indices inv_order = None if sort: x, idx, inv_order = gather_sort(x, indices) x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort) x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort) if sort: x = scatter_unsort(x, inv_order, indices.shape) return x time_fn(gather_mm, x, w1, w2, indices, False) time_fn(gather_mm, x, w1, w2, sorted_indices, False) time_fn(gather_mm, x, w1, w2, indices, True) x = mx.random.normal((N * I, D)) / 1024**0.5 w1 = mx.random.normal((M, D)) / 1024**0.5 w2 = mx.random.normal((D, M)) / 1024**0.5 mx.eval(x, w1, w2) def equivalent_matmul(x, w1, w2): x = x @ w1.T x = x @ w2.T return x time_fn(equivalent_matmul, x, w1, w2) if __name__ == "__main__": time_gather_mm() ================================================ FILE: benchmarks/python/gather_qmm_bench.py ================================================ # Copyright © 2025 Apple Inc. 
import mlx.core as mx from time_utils import time_fn N = 1024 D = 1024 M = 1024 E = 32 I = 4 def gather_sort(x, indices): N, M = indices.shape indices = indices.flatten() order = mx.argsort(indices) inv_order = mx.argsort(order) return x.flatten(0, -3)[order // M], indices[order], inv_order def scatter_unsort(x, inv_order, shape=None): x = x[inv_order] if shape is not None: x = mx.unflatten(x, 0, shape) return x def gather_mm_simulate(x, w, indices): x, idx, inv_order = gather_sort(x, indices) for i in range(2): y = mx.concatenate( [ mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True) for i, j in enumerate(idx.tolist()) ], axis=0, ) x = y[:, None] x = scatter_unsort(x, inv_order, indices.shape) return x def time_gather_qmm(): x = mx.random.normal((N, 1, 1, D)) / 1024**0.5 w1 = mx.random.normal((E, M, D)) / 1024**0.5 w2 = mx.random.normal((E, D, M)) / 1024**0.5 w1 = mx.quantize(w1) w2 = mx.quantize(w2) indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32) sorted_indices = mx.sort(indices.flatten()).reshape(N, I) mx.eval(x, w1, w2, indices, sorted_indices) def gather_mm(x, w1, w2, indices, sort): idx = indices inv_order = None if sort: x, idx, inv_order = gather_sort(x, indices) x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort) x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort) if sort: x = scatter_unsort(x, inv_order, indices.shape) return x time_fn(gather_mm, x, w1, w2, indices, False) time_fn(gather_mm, x, w1, w2, sorted_indices, False) time_fn(gather_mm, x, w1, w2, indices, True) x = mx.random.normal((N * I, D)) / 1024**0.5 w1 = mx.random.normal((M, D)) / 1024**0.5 w2 = mx.random.normal((D, M)) / 1024**0.5 w1 = mx.quantize(w1) w2 = mx.quantize(w2) mx.eval(x, w1, w2) def equivalent_matmul(x, w1, w2): x = mx.quantized_matmul(x, *w1, transpose=True) x = mx.quantized_matmul(x, *w2, transpose=True) return x time_fn(equivalent_matmul, x, w1, w2) if __name__ == "__main__": 
time_gather_qmm() ================================================ FILE: benchmarks/python/hadamard_bench.py ================================================ import argparse import matplotlib import mlx.core as mx import numpy as np from time_utils import measure_runtime matplotlib.use("Agg") import matplotlib.pyplot as plt def had(x): y = mx.hadamard_transform(x) mx.eval(y) def copy(x): y = x + 1.0 mx.eval(y) def run(dtype): system_size = 2**26 outputs = {} for test_fn in (had, copy): for m in [1, 12, 20, 28]: if test_fn == copy: key = "copy" elif m == 1: key = "had_2^k" else: key = "had_m*2^k" outputs.setdefault(key, {}) for k in range(7, 14): n = m * 2**k if n > 2**15: continue x_np = np.random.normal(size=(system_size // n, n)).astype(dtype) x = mx.array(x_np) runtime_ms = measure_runtime(test_fn, x=x) bytes_per_gb = 1e9 ms_per_s = 1e3 bytes_per_had = np.dtype(x_np.dtype).itemsize * 2 bandwidth_gb = ( system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb ) print(n, bandwidth_gb) outputs[key][n] = bandwidth_gb colors = { "copy": "black", "had_2^k": "steelblue", "had_m*2^k": "skyblue", } for key, output in outputs.items(): plt.scatter(output.keys(), output.values(), color=colors[key], label=key) plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}") plt.xlabel("N") plt.ylabel("Bandwidth (GB/s)") plt.legend() plt.savefig(f"bench_{dtype.__name__}.png") plt.clf() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--fp16", action="store_true") args = parser.parse_args() dtype = np.float16 if args.fp16 else np.float32 run(dtype) ================================================ FILE: benchmarks/python/large_gemm_bench.py ================================================ # Copyright © 2026 Apple Inc. 
import math import time import mlx.core as mx import numpy as np import torch N_WARMUP = 5 N_BENCH = 20 def bench_mlx(a, b): for _ in range(N_WARMUP): mx.eval(a @ b) times = [] for _ in range(N_BENCH): start = time.perf_counter_ns() mx.eval(a @ b) end = time.perf_counter_ns() times.append((end - start) * 1e-9) return np.mean(times), np.std(times) @torch.no_grad() def bench_torch(a, b): for _ in range(N_WARMUP): _ = a @ b torch.mps.synchronize() times = [] for _ in range(N_BENCH): start = time.perf_counter_ns() _ = a @ b torch.mps.synchronize() end = time.perf_counter_ns() times.append((end - start) * 1e-9) return np.mean(times), np.std(times) def check_correctness(out_mx, out_pt, rtol, M, N, K): if not np.allclose(out_pt, out_mx, rtol=rtol, atol=0): abs_diff = np.abs(out_pt - out_mx) rel_diff = abs_diff / np.maximum(np.abs(out_pt), 1e-10) print( f" WARNING: Correctness failed at {M}x{N}x{K}: " f"max_abs={np.max(abs_diff):.6e}, max_rel={np.max(rel_diff):.6e}" ) def bench_gemm(M, N, K, dtype, rtol): scale = 0.5 / math.sqrt(K) a_np = np.random.uniform(0, scale, (M, K)).astype(np.float32) b_np = np.random.uniform(0, scale, (K, N)).astype(np.float32) a_mx = mx.array(a_np).astype(getattr(mx, dtype)) b_mx = mx.array(b_np).astype(getattr(mx, dtype)) a_pt = torch.from_numpy(a_np).to(dtype=getattr(torch, dtype), device="mps") b_pt = torch.from_numpy(b_np).to(dtype=getattr(torch, dtype), device="mps") torch.mps.synchronize() torch_mean, torch_std = bench_torch(a_pt, b_pt) mlx_mean, mlx_std = bench_mlx(a_mx, b_mx) out_mx = (a_mx @ b_mx).astype(mx.float32) out_pt = (a_pt @ b_pt).to(torch.float32).to("cpu").numpy(force=True) check_correctness(out_mx, out_pt, rtol, M, N, K) return mlx_mean, mlx_std, torch_mean, torch_std if __name__ == "__main__": dtypes = ("bfloat16", "float16", "float32") rtols = { "float32": 1e-3, "float16": 5e-3, "bfloat16": 1e-2, } shapes = ( (2048, 2048, 10240), (2048, 3072, 10240), (3072, 3072, 10240), (3072, 3072, 12288), (3072, 4096, 12288), (4096, 4096, 
12288), (4096, 4096, 18432), (4096, 4096, 21504), (4096, 6144, 21504), (6144, 6144, 21504), ) for dtype in dtypes: print(f"\nPerformance ({dtype}):") print( f"{'M':>5s} {'N':>5s} {'K':>6s} " f"{'MLX (ms)':>15s} {'Torch (ms)':>15s} {'Speedup':>10s}" ) print("-" * 80) for M, N, K in shapes: mlx_mean, mlx_std, torch_mean, torch_std = bench_gemm( M, N, K, dtype, rtols[dtype] ) speedup = torch_mean / mlx_mean print( f"{M:5d} {N:5d} {K:6d} " f"{mlx_mean*1000:7.2f}±{mlx_std*1000:5.2f} " f"{torch_mean*1000:7.2f}±{torch_std*1000:5.2f} " f"{speedup:8.2f}x" ) ================================================ FILE: benchmarks/python/layer_norm_bench.py ================================================ # Copyright © 2023-2024 Apple Inc. from functools import partial import mlx.core as mx import mlx.nn as nn from time_utils import time_fn def layer_norm(x, w, b, eps): ot = x.dtype x = x.astype(mx.float32) mu = mx.mean(x, -1, keepdims=True) v = mx.var(x, -1, keepdims=True) y = (x - mu) * mx.rsqrt(v + eps) if w is not None: y = y * w if b is not None: y = y + b return y def time_layer_norm(N, dt): L = 1024 f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum() f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum() g1 = mx.grad(f1, argnums=(0, 1, 2)) g2 = mx.grad(f2, argnums=(0, 1, 2)) x = mx.random.uniform(shape=(8, L, N)).astype(dt) w = mx.random.uniform(shape=(N,)).astype(dt) b = mx.random.uniform(shape=(N,)).astype(dt) y = mx.random.uniform(shape=(8, L, N)).astype(dt) mx.eval(x, w, b, y) def layer_norm_loop(f, x, w, b): for _ in range(32): x = f(x, w, b) return x time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b) time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b) def layer_norm_grad_loop(g, x, w, b): gx, gw, gb = x, w, b for _ in range(32): gx, gw, gb = g(gx, gw, gb, y) return gx, gw, gb time_fn(layer_norm_grad_loop, g1, x, w, b) time_fn(layer_norm_grad_loop, g2, x, w, b) time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, 
b) time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b) f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum() f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum() g1 = mx.grad(f1, argnums=(0,)) g2 = mx.grad(f2, argnums=(0,)) x = mx.random.uniform(shape=(8, L, N)).astype(dt) w = mx.random.uniform(shape=(N,)).astype(dt) b = mx.random.uniform(shape=(N,)).astype(dt) y = mx.random.uniform(shape=(8, L, N)).astype(dt) mx.eval(x, w, b, y) def layer_norm_grad_x_loop(g, x): gx = x for _ in range(32): gx = g(gx, y) return gx time_fn(layer_norm_grad_x_loop, g1, x) time_fn(layer_norm_grad_x_loop, g2, x) time_fn(layer_norm_grad_x_loop, mx.compile(g1), x) time_fn(layer_norm_grad_x_loop, mx.compile(g2), x) if __name__ == "__main__": for dt in [mx.float32, mx.float16, mx.bfloat16]: for n in [1024, 2048, 4096, 8192, 8192 + 1024]: print(dt, n) time_layer_norm(n, dt) ================================================ FILE: benchmarks/python/masked_scatter.py ================================================ import math import os import platform import subprocess import time from copy import copy from functools import partial import matplotlib.pyplot as plt import mlx.core as mx import numpy as np import torch from matplotlib.ticker import FuncFormatter RESULTS_DIR = "./results" if not os.path.isdir(RESULTS_DIR): os.mkdir(RESULTS_DIR) TORCH_DEVICE = torch.device( "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu") ) def get_device_name(): if TORCH_DEVICE.type == "cuda": try: out = subprocess.check_output( ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], stderr=subprocess.DEVNULL, ) return out.decode("utf-8").splitlines()[0].strip() except Exception: return "CUDA_GPU" if TORCH_DEVICE.type == "mps": try: out = subprocess.check_output( ["sysctl", "-n", "machdep.cpu.brand_string"], stderr=subprocess.DEVNULL, ) return out.decode("utf-8").strip() except Exception: return "Apple_Silicon" return 
platform.processor() or platform.machine() or "CPU" DEVICE_NAME = get_device_name() N_WARMUP = 5 N_ITER_BENCH = 50 N_ITER_FUNC = 20 VECTOR_LENGTHS = [4096 * (2**i) for i in range(12)] MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5] D_TYPES = ("float32", "float16") def _power_of_two_formatter(value, _position): if value <= 0: return "" exponent = int(round(math.log2(value))) if abs(value - (1 << exponent)) / value > 1e-6: return f"{value:g}" return f"$2^{{{exponent}}}$" def torch_sync(): if TORCH_DEVICE.type == "cuda": torch.cuda.synchronize() elif TORCH_DEVICE.type == "mps": torch.mps.synchronize() def masked_scatter_mlx(self_arr, mask_arr, src_arr): outs = [] for _ in range(N_ITER_FUNC): out = copy(self_arr) out[mask_arr] = src_arr outs.append(out) mx.eval(outs) return outs @torch.no_grad() def masked_scatter_torch(self_tensor, mask_tensor, src_tensor): outs = [] for _ in range(N_ITER_FUNC): out = self_tensor.clone() out.masked_scatter_(mask_tensor, src_tensor) outs.append(out) torch_sync() return outs def measure(fn): for _ in range(N_WARMUP): fn() start = time.perf_counter_ns() for _ in range(N_ITER_BENCH): fn() end = time.perf_counter_ns() return (end - start) * 1e-9 def bytes_touched(length, true_count, item_size): mask_bytes = length self_bytes = length * item_size * 2 # read + write src_bytes = true_count * item_size return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH def build_case(length, density, np_dtype, torch_dtype): true_count = max(1, int(round(length * density))) rng = np.random.default_rng() self_np = rng.normal(0.0, 1.0, length).astype(np_dtype) mask_np = np.zeros(length, dtype=bool) mask_np[:true_count] = True rng.shuffle(mask_np) src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype) self_mlx = mx.array(self_np) mask_mlx = mx.array(mask_np) src_mlx = mx.array(src_np) self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype) mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE) src_torch = 
torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype) # Correctness check once per configuration mx_out = mx.array(self_np) mx_out[mask_mlx] = src_mlx mx.eval(mx_out) torch_out = self_torch.clone() torch_out.masked_scatter_(mask_torch, src_torch) atol = 5e-3 if np_dtype == np.float16 else 1e-5 if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol): raise AssertionError("masked_scatter results diverged between MLX and Torch") return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count) def bench_case(length, density, dtype): np_dtype = getattr(np, dtype) torch_dtype = getattr(torch, dtype) ( self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count, ) = build_case(length, density, np_dtype, torch_dtype) time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx)) time_torch = measure( partial(masked_scatter_torch, self_torch, mask_torch, src_torch) ) total_bytes = bytes_touched(length, true_count, np_dtype().itemsize) bytes_per_gb = float(1024**3) mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx torch_gbps = (total_bytes / bytes_per_gb) / time_torch return time_mlx, time_torch, mlx_gbps, torch_gbps def plot_density(ax_perf, ax_speedup, density, dtype): mlx_gbps = [] torch_gbps = [] mlx_times = [] torch_times = [] for length in VECTOR_LENGTHS: t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype) mlx_gbps.append(gbps_mlx) torch_gbps.append(gbps_torch) mlx_times.append(t_mlx) torch_times.append(t_torch) ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX") ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch") ax_perf.set_xscale("log", base=2) ax_perf.set_xticks(VECTOR_LENGTHS) formatter = FuncFormatter(_power_of_two_formatter) ax_perf.xaxis.set_major_formatter(formatter) ax_perf.set_title(f"density={density:.2f}") ax_perf.set_ylabel("GB/s") ax_perf.grid(True, which="both", linestyle=":", alpha=0.4) ax_perf.legend() speedup = 
np.array(torch_times) / np.array(mlx_times) ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green") ax_speedup.axhline(1.0, color="tab:gray", linestyle="--") ax_speedup.set_xscale("log", base=2) ax_speedup.set_xticks(VECTOR_LENGTHS) ax_speedup.xaxis.set_major_formatter(formatter) ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)") ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4) def main(): for dtype in D_TYPES: fig, axs = plt.subplots( len(MASK_DENSITIES), 2, figsize=(10, 12), layout="constrained", sharex=True, ) for i, density in enumerate(MASK_DENSITIES): plot_density(axs[i][0], axs[i][1], density, dtype) axs[i][0].set_xlabel("vector length") axs[i][1].set_xlabel("vector length") fig.suptitle( f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}" ) output_path = os.path.join( RESULTS_DIR, f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.png", ) fig.savefig(output_path) print(f"Saved benchmark image: {output_path}") plt.close(fig) if __name__ == "__main__": main() ================================================ FILE: benchmarks/python/rms_norm_bench.py ================================================ # Copyright © 2023-2024 Apple Inc. 
import mlx.core as mx import mlx.nn as nn from time_utils import time_fn def rms_norm(x, w, eps): ot = x.dtype x = x.astype(mx.float32) n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps) y = (x * n).astype(ot) if w is not None: y = y * w return y def time_rms_norm(): f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum() f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum() g1 = mx.grad(f1, argnums=(0, 1)) g2 = mx.grad(f2, argnums=(0, 1)) x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) w = mx.random.uniform(shape=(4096,)).astype(mx.float16) y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) mx.eval(x, w, y) def rms_norm_loop(g, x, w): gx, gw = x, w for _ in range(32): gx, gw = g(gx, gw, y) return gx, gw time_fn(rms_norm_loop, g1, x, w) time_fn(rms_norm_loop, g2, x, w) time_fn(rms_norm_loop, mx.compile(g1), x, w) time_fn(rms_norm_loop, mx.compile(g2), x, w) f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum() f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum() g1 = mx.grad(f1, argnums=(0,)) g2 = mx.grad(f2, argnums=(0,)) x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) w = mx.random.uniform(shape=(4096,)).astype(mx.float16) y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) mx.eval(x, w, y) def rms_norm_loop(g, x): gx = x for _ in range(32): gx = g(gx, y) return gx time_fn(rms_norm_loop, g1, x) time_fn(rms_norm_loop, g2, x) time_fn(rms_norm_loop, mx.compile(g1), x) time_fn(rms_norm_loop, mx.compile(g2), x) if __name__ == "__main__": time_rms_norm() ================================================ FILE: benchmarks/python/rope_bench.py ================================================ # Copyright © 2023-2024 Apple Inc. 
import mlx.core as mx import mlx.nn as nn from time_utils import time_fn def time_rope(): rope = nn.RoPE(64) # vec x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16) mx.eval(x) def rope_vec(x): for _ in range(32): x = rope(x, offset=100) return x time_fn(rope_vec, x) # matrix x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16) mx.eval(x) def rope_mat(x): for _ in range(32): x = rope(x) return x time_fn(rope_mat, x) if __name__ == "__main__": time_rope() ================================================ FILE: benchmarks/python/scatter_bench.py ================================================ # Copyright © 2023-2024 Apple Inc. import argparse import mlx.core as mx import torch from time_utils import measure_runtime def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes): def scatter(dst, x, idx): dst[tuple(idx)] = x mx.eval(dst) idx = [] for idx_shape in idx_shapes: idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape)) x = mx.random.normal(x_shape).astype(mx.float32) dst = mx.random.normal(dst_shape).astype(mx.float32) runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx) print(f"MLX: {runtime:.3f}ms") def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device): def scatter(dst, x, idx, device): dst[tuple(idx)] = x if device == torch.device("mps"): torch.mps.synchronize() idx = [] for idx_shape in idx_shapes: idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)) x = torch.randn(x_shape, dtype=torch.float32).to(device) dst = torch.randn(dst_shape, dtype=torch.float32).to(device) runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device) print(f"PyTorch: {runtime:.3f}ms") if __name__ == "__main__": parser = argparse.ArgumentParser("Gather benchmarks.") parser.add_argument("--cpu", action="store_true", help="Use the CPU.") args = parser.parse_args() if args.cpu: mx.set_default_device(mx.cpu) device = torch.device("cpu") else: device = torch.device("mps") dst_shapes = [ (10, 64), (100_000, 64), 
(1_000_000, 64), (100_000,), (200_000,), (20_000_000,), (10000, 64), (100, 64), (100, 10_000, 64), (10, 100, 100, 21), (1_000, 1_000, 10), ] idx_shapes = [ [(1_000_000,)], [(1_000_000,)], [(100_000,)], [(1_000_000,)], [(20_000_000,)], [(20_000_000,)], [(1000000,)], [(10000000,)], [(1_000,)], [(10_000,)], [(1_000,), (1_000,)], ] x_shapes = [ (1_000_000, 64), (1_000_000, 64), (100_000, 64), (1_000_000,), (20_000_000,), (20_000_000,), (1000000, 64), (10000000, 64), (1_000, 10_000, 64), (10_000, 100, 100, 21), (1_000, 10), ] for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes): print("=" * 20) print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}") benchmark_scatter_mlx(dst_shape, x_shape, idx_shape) benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device) ================================================ FILE: benchmarks/python/sdpa_bench.py ================================================ # Copyright © 2024 Apple Inc. import argparse import math import os import subprocess import time import mlx.core as mx import numpy as np device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) device_name = device_name.decode("utf-8").strip("\n") N_warmup = 5 N_iter_bench = 40 N_iter_func = 8 def bench(f, *args): for i in range(N_warmup): f(*args) s = time.perf_counter_ns() for i in range(N_iter_bench): f(*args) e = time.perf_counter_ns() return (e - s) * 1e-9 def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype): np_dtype = getattr(np, dtype) shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D) shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D) scale = 1.0 / math.sqrt(D) q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype) k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) q_mx = mx.array(q_np) k_mx = mx.array(k_np) v_mx = mx.array(v_np) if mask is not None: if mask == "additive": mask_np = 
np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype) mask = mx.array(mask_np) elif mask == "bool": mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5 mask = mx.array(mask_np) return q_mx, k_mx, v_mx, scale, mask def mlx_ref_attn(q, k, v, scale=1.0, mask=None): q_dtype = q.dtype q = q * mx.array(scale, q_dtype) n_q_heads = q.shape[-3] n_kv_heads = k.shape[-3] n_repeats = n_q_heads // n_kv_heads B = q.shape[0] L = q.shape[2] kL = k.shape[2] if n_repeats > 1: q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1]) k = mx.expand_dims(k, 2) v = mx.expand_dims(v, 2) scores = q @ mx.swapaxes(k, -1, -2) if mask is not None: if mask == "causal": q_offset = max(0, kL - L) q_indices = mx.arange(q_offset, q_offset + L) k_indices = mx.arange(kL) mask = q_indices[:, None] >= k_indices[None] if n_repeats > 1 and mask.ndim >= 3: if mask.shape[-3] == 1: mask = mx.expand_dims(mask, -3) else: mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats)) if mask.dtype == mx.bool_: scores = mx.where(mask, scores, -np.float32(np.inf)) else: scores += mask scores = mx.softmax(scores, axis=-1, precise=True) out = scores @ v if n_repeats > 1: out = mx.reshape(out, [B, n_q_heads, L, -1]) return out def mlx_fused_attn(q, k, v, scale, mask): return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) def do_attention(f, q, k, v, scale, mask=None, transpose=False): if transpose: q_t = mx.transpose(q, (0, 2, 1, 3)) k_t = mx.transpose(k, (0, 2, 1, 3)) v_t = mx.transpose(v, (0, 2, 1, 3)) o_t = f(q_t, k_t, v_t, scale=scale, mask=mask) return mx.transpose(o_t, (0, 2, 1, 3)) else: return f(q, k, v, scale=scale, mask=mask) def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False): q_out = q for i in range(N_iter_func): q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose) mx.eval(q_out) return q_out def bench_shape( B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None ): q_mx, k_mx, v_mx, scale, mask = 
prepare_inputs( B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype ) time_mlx_unfused = bench( do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose ) time_mlx_fused = bench( do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose ) o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose) o_mlx_unfused = do_attention( mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose ) atol = 1e-5 if dtype == "float32" else 2e-4 if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol): print( f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}" ) return time_mlx_fused, time_mlx_unfused def get_gflop_count(B, M, N, K): return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run gemm benchmarks") dtypes = ("float16", "float32")[:1] transposes = (False,) # fmt: off shapes_64 = ( # ( B, qsl, ksl, head_dim, n_qh, n_kvh) ( 1, 32, 32, 64, 32, 32), ( 1, 64, 64, 64, 32, 32), ( 1, 128, 128, 64, 32, 32), ( 1, 256, 256, 64, 32, 32), ( 1, 512, 512, 64, 32, 32), ( 1, 1024, 1024, 64, 32, 8), ( 1, 2048, 2048, 64, 32, 8), ( 1, 4096, 4096, 64, 32, 8), ) shapes_80 = ( # ( B, qsl, ksl, head_dim, n_qh, n_kvh) ( 1, 1024, 1024, 80, 32, 8), ( 1, 2048, 2048, 80, 32, 8), ( 1, 4096, 4096, 80, 32, 8), ) shapes_128 = ( # ( B, qsl, ksl, head_dim, n_qh, n_kvh) ( 1, 1024, 1024, 128, 32, 8), ( 1, 2048, 2048, 128, 32, 8), ( 1, 4096, 4096, 128, 32, 8), ) # fmt: on shapes = shapes_64 + shapes_80 + shapes_128 masks = [None, "bool", "causal"] print( " B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%" ) for dtype in dtypes: for transpose in transposes: for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes: for mask_in in masks: 
time_mlx_fused, time_mlx_unfused = bench_shape( B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose, mask_in, ) diff = time_mlx_unfused / time_mlx_fused - 1.0 t_str = 1 if transpose else 0 print( f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%" ) ================================================ FILE: benchmarks/python/sdpa_vector_bench.py ================================================ import argparse import math import mlx.core as mx from time_utils import time_fn L = 16384 H = 32 H_k = H // 4 D = 128 V = 128 dtype = mx.float16 loops = 10 def upproject(x, w): if w is None: return x else: return x @ w.T def attention(q, k, v, mask=None, w=None): def _sdpa(q, k, v): B, Hq, L, D = q.shape _, Hk, S, _ = k.shape _, _, _, V = v.shape q = q.reshape(B, Hk, Hq // Hk, L, D) k = k[:, :, None, :, :] v = v[:, :, None, :, :] s = q @ k.transpose(0, 1, 2, 4, 3) if mask is not None: m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S) s = mx.where(m, s, mx.finfo(s.dtype).min) p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) o = p @ v return o.reshape(B, Hq, L, V) for i in range(loops): q = _sdpa(q, k, v) q = upproject(q, w) return q def sdpa(q, k, v, mask=None, w=None): for i in range(loops): q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) q = upproject(q, w) return q def time_self_attention_primitives(): mx.random.seed(3) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None mx.eval(q, k, v, w) time_fn(attention, q, k, v, w=w) def time_self_attention_sdpa(): mx.random.seed(3) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) v 
= mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None mx.eval(q, k, v, w) time_fn(sdpa, q, k, v, w=w) def time_self_attention_sdpa_with_mask(): mx.random.seed(3) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None mask = mx.full((L,), True) mask[L // 2 :] = False mx.eval(q, k, v, mask, w) def sdpa_mask(*args): return sdpa(*args, mask=mask, w=w) def attention_mask(*args): return attention(*args, mask=mask, w=w) time_fn(attention_mask, q, k, v) time_fn(sdpa_mask, q, k, v) if __name__ == "__main__": time_self_attention_sdpa() time_self_attention_primitives() time_self_attention_sdpa_with_mask() ================================================ FILE: benchmarks/python/segmented_mm_bench.py ================================================ # Copyright © 2026 Apple Inc. 
import argparse import time import mlx.core as mx import numpy as np MLX_DTYPES = { "float16": mx.float16, "bfloat16": mx.bfloat16, "float32": mx.float32, } def parse_cases(cases): parsed = [] for spec in cases.split(","): m, n, k, s = [int(x) for x in spec.split("x")] parsed.append((m, n, k, s)) return parsed def make_segments(k, num_segments, pattern, seed): if pattern == "equal": cuts = np.linspace(0, k, num_segments + 1, dtype=np.int64) else: rng = np.random.default_rng(seed) cuts = rng.integers(0, k + 1, size=(num_segments - 1,), dtype=np.int64) cuts = np.sort(cuts) cuts = np.concatenate(([0], cuts, [k])) return np.stack([cuts[:-1], cuts[1:]], axis=1).astype(np.uint32) def numpy_segmented_mm_ref(a, b, segments): """Ground-truth reference in float64.""" out = [] for start, end in segments: out.append(a[:, start:end] @ b[start:end, :]) return np.stack(out, axis=0) def mlx_segmented_mm_loop(a, b, segments): """MLX loop-of-matmuls baseline.""" segments_list = segments.tolist() out = [] for start, end in segments_list: out.append(a[:, start:end] @ b[start:end, :]) return mx.stack(out, axis=0) def bench_mlx(a, b, segments, warmup, iters): for _ in range(warmup): y = mx.segmented_mm(a, b, segments) mx.eval(y) mx.synchronize() start = time.perf_counter() for _ in range(iters): y = mx.segmented_mm(a, b, segments) mx.eval(y) mx.synchronize() end = time.perf_counter() return (end - start) * 1e3 / iters def bench_mlx_loop(a, b, segments, warmup, iters): for _ in range(warmup): y = mlx_segmented_mm_loop(a, b, segments) mx.eval(y) mx.synchronize() start = time.perf_counter() for _ in range(iters): y = mlx_segmented_mm_loop(a, b, segments) mx.eval(y) mx.synchronize() end = time.perf_counter() return (end - start) * 1e3 / iters def print_table(headers, rows): widths = [len(h) for h in headers] for row in rows: for i, cell in enumerate(row): widths[i] = max(widths[i], len(cell)) def fmt_row(row): return ( "| " + " | ".join(f"{cell:<{widths[i]}}" for i, cell in enumerate(row)) 
+ " |" ) sep = "|-" + "-|-".join("-" * w for w in widths) + "-|" print(fmt_row(headers)) print(sep) for row in rows: print(fmt_row(row)) def main(): parser = argparse.ArgumentParser() parser.add_argument( "--cases", default=( "128x128x1024x16," "128x128x1024x32," "256x256x2048x16," "512x512x4096x32," "1024x1024x4096x32," "1024x1024x8192x64" ), help="Comma-separated MxNxKxS list.", ) parser.add_argument( "--dtype", default="float32", choices=["float16", "bfloat16", "float32"], ) parser.add_argument("--warmup", type=int, default=10) parser.add_argument("--iters", type=int, default=50) parser.add_argument( "--segments", choices=["equal", "random"], default="random", help="Segment generation pattern.", ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--no-check", action="store_true") args = parser.parse_args() mlx_dtype = MLX_DTYPES[args.dtype] print( f"dtype={args.dtype} warmup={args.warmup} iters={args.iters} segments={args.segments}" ) headers = [ "Case", "MLX ms", "Loop ms", "Speedup", "MLX err", "Loop err", ] rows = [] cases = parse_cases(args.cases) for idx, (m, n, k, s) in enumerate(cases): rng = np.random.default_rng(args.seed + idx) a_np = rng.standard_normal((m, k)).astype(np.float32) b_np = rng.standard_normal((k, n)).astype(np.float32) seg_np = make_segments(k, s, args.segments, args.seed + idx) a_mx = mx.array(a_np, dtype=mlx_dtype) b_mx = mx.array(b_np, dtype=mlx_dtype) seg_mx = mx.array(seg_np, dtype=mx.uint32) mx.eval(a_mx, b_mx, seg_mx) mlx_err_str = "" loop_err_str = "" if not args.no_check: y_mlx = mx.segmented_mm(a_mx, b_mx, seg_mx) y_loop = mlx_segmented_mm_loop(a_mx, b_mx, seg_mx) mx.eval(y_mlx, y_loop) if args.dtype == "float32": ref = numpy_segmented_mm_ref( a_np.astype(np.float64), b_np.astype(np.float64), seg_np.tolist(), ) mlx_err = np.max(np.abs(np.array(y_mlx, dtype=np.float64) - ref)) loop_err = np.max(np.abs(np.array(y_loop, dtype=np.float64) - ref)) else: a_mx_f32 = mx.array(a_np, dtype=mx.float32) b_mx_f32 = 
mx.array(b_np, dtype=mx.float32) ref = mx.segmented_mm(a_mx_f32, b_mx_f32, seg_mx) mx.eval(ref) mlx_err = float(mx.max(mx.abs(ref - y_mlx.astype(mx.float32))).item()) loop_err = float(mx.max(mx.abs(ref - y_loop.astype(mx.float32))).item()) mlx_err_str = f"{mlx_err:.2e}" loop_err_str = f"{loop_err:.2e}" t_mlx = bench_mlx(a_mx, b_mx, seg_mx, args.warmup, args.iters) t_loop = bench_mlx_loop(a_mx, b_mx, seg_mx, args.warmup, args.iters) ratio = t_loop / t_mlx if t_mlx > 0 else float("inf") rows.append( [ f"{m}x{n}x{k}x{s}", f"{t_mlx:.3f}", f"{t_loop:.3f}", f"{ratio:.2f}x", mlx_err_str, loop_err_str, ] ) print_table(headers, rows) if not args.no_check: if args.dtype == "float32": print("err: max|result - numpy_fp64_ref|") else: print("err: max|result - own_fp32_result|") if __name__ == "__main__": main() ================================================ FILE: benchmarks/python/single_ops.py ================================================ # Copyright © 2023 Apple Inc. import argparse import mlx.core as mx from time_utils import time_fn def time_add(): a = mx.random.uniform(shape=(32, 1024, 1024)) b = mx.random.uniform(shape=(32, 1024, 1024)) mx.eval(a, b) time_fn(mx.add, a, b) aT = mx.transpose(a, [0, 2, 1]) mx.eval(aT) def transpose_add(a, b): return mx.add(a, b) time_fn(transpose_add, aT, b) b = mx.random.uniform(shape=(1024,)) mx.eval(b) def slice_add(a, b): return mx.add(a, b) time_fn(slice_add, a, b) b = mx.reshape(b, (1, 1024, 1)) mx.eval(b) def mid_slice_add(a, b): return mx.add(a, b) time_fn(mid_slice_add, a, b) def time_matmul(): a = mx.random.uniform(shape=(1024, 1024)) b = mx.random.uniform(shape=(1024, 1024)) mx.eval(a, b) time_fn(mx.matmul, a, b) def time_maximum(): a = mx.random.uniform(shape=(32, 1024, 1024)) b = mx.random.uniform(shape=(32, 1024, 1024)) mx.eval(a, b) time_fn(mx.maximum, a, b) def time_max(): a = mx.random.uniform(shape=(32, 1024, 1024)) a[1, 1] = mx.nan mx.eval(a) time_fn(mx.max, a, 0) def time_min(): a = mx.random.uniform(shape=(32, 1024, 
1024)) a[1, 1] = mx.nan mx.eval(a) time_fn(mx.min, a, 0) def time_negative(): a = mx.random.uniform(shape=(10000, 1000)) mx.eval(a) def negative(a): return -a mx.eval(a) time_fn(negative, a) def time_exp(): a = mx.random.uniform(shape=(1000, 100)) mx.eval(a) time_fn(mx.exp, a) def time_logsumexp(): a = mx.random.uniform(shape=(64, 10, 10000)) mx.eval(a) time_fn(mx.logsumexp, a, axis=-1) def time_take(): a = mx.random.uniform(shape=(10000, 500)) ids = mx.random.randint(low=0, high=10000, shape=(20, 10)) ids = [mx.reshape(idx, (-1,)) for idx in ids] mx.eval(ids) def random_take(): return [mx.take(a, idx, 0) for idx in ids] time_fn(random_take) def time_reshape_transposed(): x = mx.random.uniform(shape=(256, 256, 128)) mx.eval(x) def reshape_transposed(): return mx.reshape(mx.transpose(x, (1, 0, 2)), (-1,)) time_fn(reshape_transposed) if __name__ == "__main__": parser = argparse.ArgumentParser("MLX benchmarks.") parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") args = parser.parse_args() if args.gpu: mx.set_default_device(mx.gpu) else: mx.set_default_device(mx.cpu) time_add() time_matmul() time_min() time_max() time_maximum() time_exp() time_negative() time_logsumexp() time_take() time_reshape_transposed() ================================================ FILE: benchmarks/python/slice_update_bench.py ================================================ # Copyright © 2023-2024 Apple Inc. 
import argparse import mlx.core as mx import torch from time_utils import measure_runtime def benchmark_slice_update_mlx(dst_shape, slice_shape, slice_range, dtype, iters=10): def slice_update(arguments): for i in range(iters): arguments["dst"] = ( arguments["dst"].at[slice_range].add(arguments["updates"]) ) mx.eval(arguments) dtype = getattr(mx, dtype) arguments = { "dst": mx.random.normal(dst_shape).astype(dtype), "updates": mx.random.normal(slice_shape).astype(dtype), } runtime = measure_runtime(slice_update, arguments=arguments) bytes_processed = ( arguments["dst"][slice_range].nbytes * 2 + arguments["updates"].nbytes ) * iters bandwidth_gb_s = bytes_processed / runtime / 1e6 return runtime, bandwidth_gb_s def benchmark_slice_update_torch( dst_shape, slice_shape, slice_range, device, dtype, iters=10 ): def slice_update(dst, updates, slice_range): for i in range(iters): dst[slice_range] = dst[slice_range] + updates if device == torch.device("mps"): torch.mps.synchronize() dtype = getattr(torch, dtype) updates = torch.randn(slice_shape, dtype=dtype).to(device) dst = torch.randn(dst_shape, dtype=dtype).to(device) runtime = measure_runtime( slice_update, dst=dst, updates=updates, slice_range=slice_range ) bytes_processed = (dst[slice_range].nbytes * 2 + updates.nbytes) * iters bandwidth_gb_s = bytes_processed / runtime / 1e6 return runtime, bandwidth_gb_s if __name__ == "__main__": parser = argparse.ArgumentParser("Slice update benchmarks.") parser.add_argument("--cpu", action="store_true", help="Use the CPU.") args = parser.parse_args() if args.cpu: mx.set_default_device(mx.cpu) device = torch.device("cpu") elif torch.mps.is_available(): device = torch.device("mps") elif torch.cuda.is_available(): device = torch.device("cuda") else: raise ValueError() dtypes = ["float32", "bfloat16"] test_cases = [ ((10_000_000,), slice(0, 1_000_000), (1_000_000,)), ((100_000,), slice(10_000, 20_000), (10_000,)), ((1000, 64), slice(100, 200), (100, 64)), ((100, 100, 64), slice(20, 
40), (20, 100, 64)), ( (2048, 2048, 128), (slice(500, 1500), slice(200, 1200), slice(32, 96)), (1000, 1000, 64), ), ( (2048, 2048, 128), (slice(1800, 1850), slice(100, 200), slice(64, 128)), (50, 100, 64), ), ( (2048, 2048, 128), (slice(1000, 1010), slice(1000, 1010), slice(64, 128)), (10, 10, 64), ), ] print( f"{'Dtype':<12} {'Dst Shape':<25} {'Update Shape':<20} " f"{'MLX (ms)':<12} {'MLX GB/s':<12} {'Torch (ms)':<12} {'Torch GB/s':<12}" ) print("-" * 110) for dtype in dtypes: for dst_shape, slice_range, update_shape in test_cases: mlx_time, mlx_bw = benchmark_slice_update_mlx( dst_shape, update_shape, slice_range, dtype ) torch_time, torch_bw = benchmark_slice_update_torch( dst_shape, update_shape, slice_range, device, dtype ) print( f"{dtype:<12} {str(dst_shape):<25} {str(update_shape):<20} " f"{mlx_time:<12.3f} {mlx_bw:<12.2f} {torch_time:<12.3f} {torch_bw:<12.2f}" ) ================================================ FILE: benchmarks/python/synchronize_bench.py ================================================ import time import mlx.core as mx rank = mx.distributed.init().rank() def timeit(fn, a): # warmup for _ in range(5): mx.eval(fn(a)) its = 10 tic = time.perf_counter() for _ in range(its): mx.eval(fn(a)) toc = time.perf_counter() ms = 1000 * (toc - tic) / its return ms def all_reduce_benchmark(): a = mx.ones((5, 5), mx.int32) its_per_eval = 100 def fn(x): for _ in range(its_per_eval): x = mx.distributed.all_sum(x) x = x - 1 return x ms = timeit(fn, a) / its_per_eval if rank == 0: print(f"All Reduce: time per iteration {ms:.6f} (ms)") def all_gather_benchmark(): a = mx.ones((5, 5), mx.int32) its_per_eval = 100 def fn(x): for _ in range(its_per_eval): x = mx.distributed.all_gather(x)[0] return x ms = timeit(fn, a) / its_per_eval if rank == 0: print(f"All gather: time per iteration {ms:.6f} (ms)") if __name__ == "__main__": all_reduce_benchmark() all_gather_benchmark() ================================================ FILE: benchmarks/python/time_utils.py 
================================================ # Copyright © 2023-2024 Apple Inc. import time import mlx.core as mx def time_fn(fn, *args, **kwargs): msg = kwargs.pop("msg", None) if msg: print(f"Timing {msg} ...", end=" ") else: print(f"Timing {fn.__name__} ...", end=" ") # warmup for _ in range(5): mx.eval(fn(*args, **kwargs)) num_iters = 100 tic = time.perf_counter() for _ in range(num_iters): x = mx.eval(fn(*args, **kwargs)) toc = time.perf_counter() msec = 1e3 * (toc - tic) / num_iters print(f"{msec:.5f} msec") def measure_runtime(fn, **kwargs): # Warmup for _ in range(5): fn(**kwargs) tic = time.perf_counter() iters = 100 for _ in range(iters): fn(**kwargs) return (time.perf_counter() - tic) * 1000 / iters ================================================ FILE: cmake/FindCUDNN.cmake ================================================ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# Modified from
# https://github.com/NVIDIA/cudnn-frontend/blob/main/cmake/cuDNN.cmake

# Return the last file matching the pattern.
function(find_file_glob VAR PATTERN)
  file(GLOB _RESULT "${PATTERN}")
  if(_RESULT)
    # list() sub-commands take the *name* of the list variable. Passing the
    # expanded value "${_RESULT}" breaks as soon as more than one file matches.
    list(LENGTH _RESULT _RESULT_LENGTH)
    if(_RESULT_LENGTH GREATER 0)
      list(GET _RESULT -1 _RESULT)
    endif()
    set(${VAR}
        "${_RESULT}"
        PARENT_SCOPE)
  endif()
endfunction()

# Find the dir including the "cudnn.h" file.
find_path(
  CUDNN_INCLUDE_DIR cudnn.h
  HINTS ${CUDNN_INCLUDE_PATH} ${CUDAToolkit_INCLUDE_DIRS}
  PATH_SUFFIXES include OPTIONAL)

# Glob searching "cudnn.h" for Windows.
if(WIN32 AND NOT CUDNN_INCLUDE_DIR)
  find_file_glob(
    CUDNN_H_PATH
    "C:/Program Files/NVIDIA/CUDNN/*/include/${CUDAToolkit_VERSION_MAJOR}.*/cudnn.h"
  )
  if(CUDNN_H_PATH)
    get_filename_component(CUDNN_INCLUDE_DIR "${CUDNN_H_PATH}" DIRECTORY)
  endif()
endif()

if(NOT CUDNN_INCLUDE_DIR)
  message(
    FATAL_ERROR
      "Unable to find cudnn.h, please make sure cuDNN is installed and pass CUDNN_INCLUDE_PATH to cmake."
  )
endif()

# Get cudnn version. Match any digits so majors containing 0 (e.g. 10) parse
# correctly.
file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header)
string(REGEX MATCH "#define CUDNN_MAJOR [0-9]+" macrodef
             "${cudnn_version_header}")
string(REGEX MATCH "[0-9]+" CUDNN_MAJOR_VERSION "${macrodef}")

# Function for searching library files. Pass OPTIONAL as the second argument
# to only report (not fail) when the library is missing; a CUDNN::<NAME>
# imported target is created only when the library is found.
function(find_cudnn_library NAME)
  if(NOT "${ARGV1}" STREQUAL "OPTIONAL")
    set(_CUDNN_REQUIRED TRUE)
  else()
    set(_CUDNN_REQUIRED FALSE)
  endif()

  find_library(
    ${NAME}_LIBRARY
    NAMES ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}" NAMES_PER_DIR
    HINTS ${CUDNN_LIBRARY_PATH} ${CUDAToolkit_LIBRARY_DIR}
    PATH_SUFFIXES lib64 lib/x64 lib OPTIONAL)

  if(WIN32 AND NOT ${NAME}_LIBRARY)
    find_file_glob(
      ${NAME}_LIBRARY
      "C:/Program Files/NVIDIA/CUDNN/*/lib/${CUDAToolkit_VERSION_MAJOR}.*/x64/${NAME}.lib"
    )
  endif()

  if(NOT ${NAME}_LIBRARY AND ${_CUDNN_REQUIRED})
    message(
      FATAL_ERROR
        "Unable to find ${NAME}, please make sure cuDNN is installed and pass CUDNN_LIBRARY_PATH to cmake."
    )
  endif()

  if(${NAME}_LIBRARY)
    add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
    set_target_properties(
      CUDNN::${NAME} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
                                IMPORTED_LOCATION ${${NAME}_LIBRARY})
    set(${NAME}_LIBRARY
        "${${NAME}_LIBRARY}"
        PARENT_SCOPE)
  else()
    message(STATUS "${NAME} not found.")
  endif()
endfunction()

# Search for the main cudnn library.
find_cudnn_library(cudnn)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(CUDNN REQUIRED_VARS CUDNN_INCLUDE_DIR
                                                      cudnn_LIBRARY)

if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)
  set(CUDNN_FOUND
      ON
      CACHE INTERNAL "cuDNN Library Found")
else()
  set(CUDNN_FOUND
      OFF
      CACHE INTERNAL "cuDNN Library Not Found")
endif()

# Find out all the DLL files for Windows.
if(WIN32 AND cudnn_LIBRARY)
  get_filename_component(CUDNN_BIN_DIR "${cudnn_LIBRARY}" DIRECTORY)
  string(REPLACE "/lib/" "/bin/" CUDNN_BIN_DIR "${CUDNN_BIN_DIR}")
  file(
    GLOB CUDNN_DLL_NAMES
    RELATIVE "${CUDNN_BIN_DIR}"
    "${CUDNN_BIN_DIR}/*.dll")
endif()

# Create an interface library that users can link with.
add_library(CUDNN::cudnn_all INTERFACE IMPORTED)
target_link_libraries(CUDNN::cudnn_all INTERFACE CUDNN::cudnn)
# Expose the cuDNN headers to consumers of CUDNN::cudnn_all.
target_include_directories(CUDNN::cudnn_all INTERFACE "${CUDNN_INCLUDE_DIR}")

# Add other components of cudnn.
if(CUDNN_MAJOR_VERSION EQUAL 8)
  find_cudnn_library(cudnn_adv_infer)
  find_cudnn_library(cudnn_adv_train)
  find_cudnn_library(cudnn_cnn_infer)
  find_cudnn_library(cudnn_cnn_train)
  find_cudnn_library(cudnn_ops_infer)
  find_cudnn_library(cudnn_ops_train)

  target_link_libraries(
    CUDNN::cudnn_all
    INTERFACE CUDNN::cudnn_adv_train CUDNN::cudnn_ops_train
              CUDNN::cudnn_cnn_train CUDNN::cudnn_adv_infer
              CUDNN::cudnn_cnn_infer CUDNN::cudnn_ops_infer)

elseif(CUDNN_MAJOR_VERSION EQUAL 9)
  find_cudnn_library(cudnn_graph)
  find_cudnn_library(cudnn_engines_runtime_compiled)
  find_cudnn_library(cudnn_ops OPTIONAL)
  find_cudnn_library(cudnn_cnn OPTIONAL)
  find_cudnn_library(cudnn_adv OPTIONAL)
  find_cudnn_library(cudnn_engines_precompiled OPTIONAL)
  find_cudnn_library(cudnn_heuristic OPTIONAL)

  target_link_libraries(
    CUDNN::cudnn_all INTERFACE CUDNN::cudnn_graph
                               CUDNN::cudnn_engines_runtime_compiled)

  # The components below were searched with OPTIONAL, so their CUDNN::<name>
  # imported targets exist only when the library was found. A link item
  # containing "::" that is not a target is a generate-time error, so link
  # only the ones that exist (original order preserved).
  foreach(_cudnn_optional_component cudnn_ops cudnn_cnn cudnn_adv
                                    cudnn_engines_precompiled cudnn_heuristic)
    if(TARGET CUDNN::${_cudnn_optional_component})
      target_link_libraries(CUDNN::cudnn_all
                            INTERFACE CUDNN::${_cudnn_optional_component})
    endif()
  endforeach()
endif()

================================================
FILE: cmake/FindNCCL.cmake
================================================
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.

set(NCCL_ROOT_DIR
    $ENV{NCCL_ROOT_DIR}
    CACHE PATH "Folder contains NVIDIA NCCL")

find_path(
  NCCL_INCLUDE_DIRS
  NAMES nccl.h
  HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
        ${CUDA_TOOLKIT_ROOT_DIR}/include)

if($ENV{USE_STATIC_NCCL})
  message(
    STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
  set(NCCL_LIBNAME "libnccl_static.a")
else()
  set(NCCL_LIBNAME "nccl")
endif()

find_library(
  NCCL_LIBRARIES
  NAMES ${NCCL_LIBNAME}
  HINTS ${NCCL_LIB_DIR}
        ${NCCL_ROOT_DIR}
        ${NCCL_ROOT_DIR}/lib
        ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
        ${NCCL_ROOT_DIR}/lib64
        ${CUDA_TOOLKIT_ROOT_DIR}/lib
        ${CUDA_TOOLKIT_ROOT_DIR}/lib64)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
                                  NCCL_LIBRARIES)

if(NCCL_FOUND)
  # Extract the NCCL major version from the "#define NCCL_MAJOR" line.
  set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
  message(
    STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
  file(
    STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
    REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
    LIMIT_COUNT 1)
  if(NCCL_MAJOR_VERSION_DEFINED)
    string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
                         NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
    message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
  endif()
  message(
    STATUS
      "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

================================================
FILE: cmake/Findnvpl.cmake
================================================
# This file does nothing but to suppress the cmake warning: "By not providing
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.
================================================
FILE: cmake/extension.cmake
================================================
include(CMakeParseArguments)

# cmake-format: off
#
# ##############################################################################
# Build metal library
#
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/${TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
#
# Args:
#   TARGET: Custom target to be added for the metal library
#   TITLE: Name of the .metallib
#   OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
#   SOURCES: List of source files
#   INCLUDE_DIRS: List of include dirs
#   DEPS: List of dependency files (like headers)
#   DEBUG: Boolean, if true, enables debug compile options for this specific
#          library. If not provided, uses global MLX_METAL_DEBUG.
#
# cmake-format: on

macro(mlx_build_metallib)
  # Parse args
  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
  set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
  cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  # Set output
  set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")

  # Collect compile options
  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
  if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
    set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
                               -frecord-sources)
  endif()

  # Turn every include dir into its own "-I<dir>" argument for the metal
  # compiler. A macro runs in the caller's scope, so reset the list first to
  # avoid accumulating flags across invocations.
  set(MTLLIB_INCLUDE_FLAGS "")
  foreach(_mtllib_include_dir IN LISTS MTLLIB_INCLUDE_DIRS)
    list(APPEND MTLLIB_INCLUDE_FLAGS "-I${_mtllib_include_dir}")
  endforeach()

  # Prepare metallib build command
  add_custom_command(
    OUTPUT ${MTLLIB_BUILD_TARGET}
    COMMAND xcrun -sdk macosx metal ${MTLLIB_INCLUDE_FLAGS}
            ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
    DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
    COMMAND_EXPAND_LISTS
    COMMENT "Building ${MTLLIB_TITLE}.metallib"
    VERBATIM)

  # Add metallib custom target
  add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})

endmacro(mlx_build_metallib)

================================================
FILE: docs/.clang-format
================================================
DisableFormat: true
SortIncludes: Never
================================================ FILE: docs/.gitignore ================================================ src/python/_autosummary*/ src/python/nn/_autosummary*/ src/python/optimizers/_autosummary*/ ================================================ FILE: docs/.nojekyll ================================================ ================================================ FILE: docs/Doxyfile ================================================ ################################################################################ # Primary project setup. # ################################################################################ PROJECT_NAME = "MLX" OUTPUT_DIRECTORY = build XML_OUTPUT = xml HTML_OUTPUT = html STRIP_FROM_PATH = ../ INPUT = ../mlx FILE_PATTERNS = *.h EXCLUDE_PATTERNS = */private/* CREATE_SUBDIRS = NO FULL_PATH_NAMES = YES RECURSIVE = YES GENERATE_HTML = NO GENERATE_LATEX = NO GENERATE_XML = YES XML_PROGRAMLISTING = YES ################################################################################ # Doxygen preprocessor / parser control. # ################################################################################ ENABLE_PREPROCESSING = YES MACRO_EXPANSION = YES EXPAND_ONLY_PREDEF = NO SKIP_FUNCTION_MACROS = NO PREDEFINED = MLX_API= ################################################################################ # Compound extraction control. # ################################################################################ EXTRACT_ALL = YES EXTRACT_PACKAGE = YES EXTRACT_STATIC = YES CASE_SENSE_NAMES = NO ################################################################################ # Docstring control / customization. # ################################################################################ JAVADOC_AUTOBRIEF = YES ################################################################################ # Warning suppression. 
# ################################################################################ QUIET = YES WARN_IF_UNDOCUMENTED = NO ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = src BUILDDIR = build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/README.md ================================================ ## Build the Docs ### Setup (do once) Install Doxygen: ``` brew install doxygen ``` Install Python packages: ``` pip install -r requirements.txt ``` ### Build Build the docs from `mlx/docs/` ``` doxygen && make html ``` View the docs by running a server in `mlx/docs/build/html/`: ``` python -m http.server ``` and point your browser to `http://localhost:8000` (`python -m http.server` serves on port 8000 unless you pass a different port). ### Push to GitHub Pages Check out the `gh-pages` branch (`git switch gh-pages`) and build the docs. Then force add the `build/html` directory: `git add -f build/html` Commit and push the changes to the `gh-pages` branch. 
## Doc Development Setup To enable live refresh of docs while writing: Install sphinx autobuild ``` pip install sphinx-autobuild ``` Run auto build on docs/src folder ``` sphinx-autobuild ./src ./build/html ``` ================================================ FILE: docs/index.html ================================================ ================================================ FILE: docs/requirements.txt ================================================ sphinx breathe sphinx-book-theme sphinx-copybutton mlx ================================================ FILE: docs/src/_templates/module-base-class.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. add toctree option to make autodoc generate the pages .. autoclass:: {{ objname }} {% block attributes %} {% if attributes %} .. rubric:: Attributes .. autosummary:: :toctree: . {% for item in attributes %} ~{{ fullname }}.{{ item }} {%- endfor %} {% endif %} {% endblock %} {% block methods %} {% if methods %} .. rubric:: Methods .. autosummary:: :toctree: . {% for item in methods %} {%- if item not in inherited_members and item != '__init__' %} ~{{ fullname }}.{{ item }} {%- endif -%} {%- endfor %} {% endif %} {% endblock %} ================================================ FILE: docs/src/_templates/nn-module-template.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} {% block methods %} {% if methods %} .. rubric:: {{ _('Methods') }} .. autosummary:: {% for item in methods %} {%- if item not in inherited_members and item != "__init__" %} ~{{ name }}.{{ item }} {%- endif %} {%- endfor %} {% endif %} {% endblock %} ================================================ FILE: docs/src/_templates/optimizers-template.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. 
autoclass:: {{ objname }} {% block methods %} {% if methods %} .. rubric:: {{ _('Methods') }} .. autosummary:: {% for item in methods %} {%- if item not in inherited_members %} ~{{ name }}.{{ item }} {%- endif %} {%- endfor %} {% endif %} {% endblock %} ================================================ FILE: docs/src/conf.py ================================================ # Copyright © 2023 Apple Inc. # -*- coding: utf-8 -*- import os import subprocess import mlx.core as mx # -- Project information ----------------------------------------------------- project = "MLX" copyright = "2023, Apple" author = "MLX Contributors" version = ".".join(mx.__version__.split(".")[:3]) release = version # -- General configuration --------------------------------------------------- extensions = [ "sphinx_copybutton", "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.intersphinx", "sphinx.ext.napoleon", "breathe", ] python_use_unqualified_type_names = True autosummary_generate = True autosummary_filename_map = {"mlx.core.Stream": "stream_class"} intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "numpy": ("https://numpy.org/doc/stable/", None), } breathe_projects = {"mlx": "../build/xml"} breathe_default_project = "mlx" templates_path = ["_templates"] html_static_path = ["_static"] source_suffix = ".rst" main_doc = "index" highlight_language = "python" pygments_style = "sphinx" add_module_names = False # -- Options for HTML output ------------------------------------------------- html_theme = "sphinx_book_theme" html_theme_options = { "show_toc_level": 2, "repository_url": "https://github.com/ml-explore/mlx", "use_repository_button": True, "navigation_with_keys": False, "logo": { "image_light": "_static/mlx_logo.png", "image_dark": "_static/mlx_logo_dark.png", }, } html_favicon = html_theme_options["logo"]["image_light"] # -- Options for HTMLHelp output --------------------------------------------- htmlhelp_basename = "mlx_doc" def setup(app): from 
sphinx.util import inspect wrapped_isfunc = inspect.isfunction def isfunc(obj): type_name = str(type(obj)) if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name: return True return wrapped_isfunc(obj) inspect.isfunction = isfunc # -- Options for LaTeX output ------------------------------------------------ latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")] latex_elements = { "preamble": r""" \usepackage{enumitem} \setlistdepth{5} \setlist[itemize,1]{label=$\bullet$} \setlist[itemize,2]{label=$\bullet$} \setlist[itemize,3]{label=$\bullet$} \setlist[itemize,4]{label=$\bullet$} \setlist[itemize,5]{label=$\bullet$} \renewlist{itemize}{itemize}{5} """, } ================================================ FILE: docs/src/cpp/ops.rst ================================================ .. _cpp_ops: Operations ========== .. doxygengroup:: ops :content-only: ================================================ FILE: docs/src/dev/custom_metal_kernels.rst ================================================ .. _custom_metal_kernels: Custom Metal Kernels ==================== MLX supports writing custom Metal kernels through the Python and C++ APIs. Simple Example -------------- .. currentmodule:: mlx.core Let's write a custom kernel that computes ``exp`` elementwise: .. code-block:: python source = """ uint elem = thread_position_in_grid.x; T tmp = inp[elem]; out[elem] = metal::exp(tmp); """ kernel = mx.fast.metal_kernel( name="myexp", input_names=["inp"], output_names=["out"], source=source, ) def exp_elementwise(a: mx.array): outputs = kernel( inputs=[a], template=[("T", mx.float32)], grid=(a.size, 1, 1), threadgroup=(256, 1, 1), output_shapes=[a.shape], output_dtypes=[a.dtype], ) return outputs[0] a = mx.random.normal(shape=(4, 16)).astype(mx.float16) b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) Every time you make a kernel, a new Metal library is created and possibly JIT compiled. 
To reduce the overhead from that, build the kernel once with :func:`fast.metal_kernel` and then use it many times. .. note:: Only pass the body of the Metal kernel in ``source``. The function signature is generated automatically. The full function signature will be generated using: * The shapes/dtypes of ``inputs`` In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp`` so we will add ``const device float16_t* inp`` to the signature. ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present in ``source``. * The list of ``output_dtypes`` In the above, ``out`` is an ``mx.array`` of type ``mx.float16`` so we add ``device float16_t* out``. * Template parameters passed using ``template`` In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function and instantiates the template with ``custom_kernel_myexp_float<float>``. Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``. * Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]`` These will be added as function arguments. All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported. Putting this all together, the generated function signature for ``myexp`` is as follows: .. code-block:: cpp template <typename T> [[kernel]] void custom_kernel_myexp_float( const device float16_t* inp [[buffer(0)]], device float16_t* out [[buffer(1)]], uint3 thread_position_in_grid [[thread_position_in_grid]]) { uint elem = thread_position_in_grid.x; T tmp = inp[elem]; out[elem] = metal::exp(tmp); } template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>; Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function. This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups. 
For optimal performance, each threadgroup dimension should be less than or equal to the corresponding grid dimension. Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the generated code for debugging purposes. Using Shape/Strides ------------------- :func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which is ``True`` by default. This will copy the array inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims when indexing. If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are present in ``source``. We can then use MLX's built-in indexing utils to fetch the right elements for each thread. Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``: .. code-block:: python source = """ uint elem = thread_position_in_grid.x; // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); T tmp = inp[loc]; // Output arrays are always row contiguous out[elem] = metal::exp(tmp); """ kernel = mx.fast.metal_kernel( name="myexp_strided", input_names=["inp"], output_names=["out"], source=source, ensure_row_contiguous=False, ) def exp_elementwise(a: mx.array): outputs = kernel( inputs=[a], template=[("T", mx.float32)], grid=(a.size, 1, 1), threadgroup=(256, 1, 1), output_shapes=[a.shape], output_dtypes=[a.dtype], ) return outputs[0] a = mx.random.normal(shape=(4, 16)).astype(mx.float16) # make non-contiguous a = a[::2] b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) Complex Example ----------------------------- Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode. 
We'll start with the following MLX implementation using standard ops: .. code-block:: python def grid_sample_ref(x, grid): N, H_in, W_in, _ = x.shape ix = ((grid[..., 0] + 1) * W_in - 1) / 2 iy = ((grid[..., 1] + 1) * H_in - 1) / 2 ix_nw = mx.floor(ix).astype(mx.int32) iy_nw = mx.floor(iy).astype(mx.int32) ix_ne = ix_nw + 1 iy_ne = iy_nw ix_sw = ix_nw iy_sw = iy_nw + 1 ix_se = ix_nw + 1 iy_se = iy_nw + 1 nw = (ix_se - ix) * (iy_se - iy) ne = (ix - ix_sw) * (iy_sw - iy) sw = (ix_ne - ix) * (iy - iy_ne) se = (ix - ix_nw) * (iy - iy_nw) I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) I_nw *= mask_nw[..., None] I_ne *= mask_ne[..., None] I_sw *= mask_sw[..., None] I_se *= mask_se[..., None] output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se return output Now let's use :func:`custom_function` together with :func:`fast.metal_kernel` to write a fast GPU kernel for both the forward and backward passes. First we'll implement the forward pass as a fused kernel: .. 
code-block:: python source = """ uint elem = thread_position_in_grid.x; int H = x_shape[1]; int W = x_shape[2]; int C = x_shape[3]; int gH = grid_shape[1]; int gW = grid_shape[2]; int w_stride = C; int h_stride = W * w_stride; int b_stride = H * h_stride; uint grid_idx = elem / C * 2; float ix = ((grid[grid_idx] + 1) * W - 1) / 2; float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; int ix_nw = floor(ix); int iy_nw = floor(iy); int ix_ne = ix_nw + 1; int iy_ne = iy_nw; int ix_sw = ix_nw; int iy_sw = iy_nw + 1; int ix_se = ix_nw + 1; int iy_se = iy_nw + 1; T nw = (ix_se - ix) * (iy_se - iy); T ne = (ix - ix_sw) * (iy_sw - iy); T sw = (ix_ne - ix) * (iy - iy_ne); T se = (ix - ix_nw) * (iy - iy_nw); int batch_idx = elem / C / gH / gW * b_stride; int channel_idx = elem % C; int base_idx = batch_idx + channel_idx; T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; """ kernel = mx.fast.metal_kernel( name="grid_sample", input_names=["x", "grid"], output_names=["out"], source=source, ) @mx.custom_function def grid_sample(x, grid): assert x.ndim == 4, "`x` must be 4D." assert grid.ndim == 4, "`grid` must be 4D." B, _, _, C = x.shape _, gN, gM, D = grid.shape out_shape = (B, gN, gM, C) assert D == 2, "Last dim of `grid` must be size 2." 
outputs = kernel( inputs=[x, grid], template=[("T", x.dtype)], output_shapes=[out_shape], output_dtypes=[x.dtype], grid=(np.prod(out_shape), 1, 1), threadgroup=(256, 1, 1), ) return outputs[0] For a reasonably sized input such as: .. code-block:: python x.shape = (8, 1024, 1024, 64) grid.shape = (8, 256, 256, 2) On an M1 Max, we see a big performance improvement: ``55.7ms -> 6.7ms => 8x speed up`` Grid Sample VJP --------------- Since we decorated ``grid_sample`` with :func:`custom_function`, we can now define its custom vjp transform so MLX can differentiate it. The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so requires a few extra :func:`fast.metal_kernel` features: * ``init_value=0`` Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. * ``atomic_outputs=True`` Designate all of the kernel outputs as ``atomic`` in the function signature. This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups. See section 6.15 of the `Metal Shading Language Specification `_ for more details. We can then implement the backwards pass as follows: .. 
code-block:: python source = """ uint elem = thread_position_in_grid.x; int H = x_shape[1]; int W = x_shape[2]; int C = x_shape[3]; // Pad C to the nearest larger simdgroup size multiple int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; int gH = grid_shape[1]; int gW = grid_shape[2]; int w_stride = C; int h_stride = W * w_stride; int b_stride = H * h_stride; uint grid_idx = elem / C_padded * 2; float ix = ((grid[grid_idx] + 1) * W - 1) / 2; float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; int ix_nw = floor(ix); int iy_nw = floor(iy); int ix_ne = ix_nw + 1; int iy_ne = iy_nw; int ix_sw = ix_nw; int iy_sw = iy_nw + 1; int ix_se = ix_nw + 1; int iy_se = iy_nw + 1; T nw = (ix_se - ix) * (iy_se - iy); T ne = (ix - ix_sw) * (iy_sw - iy); T sw = (ix_ne - ix) * (iy - iy_ne); T se = (ix - ix_nw) * (iy - iy_nw); int batch_idx = elem / C_padded / gH / gW * b_stride; int channel_idx = elem % C_padded; int base_idx = batch_idx + channel_idx; T gix = T(0); T giy = T(0); if (channel_idx < C) { int cot_index = elem / C_padded * C + channel_idx; T cot = cotangent[cot_index]; if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); T I_nw = x[offset]; gix -= I_nw * (iy_se - iy) * cot; giy -= I_nw * (ix_se - ix) * cot; } if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); T I_ne = x[offset]; gix += I_ne * (iy_sw - iy) * cot; giy -= I_ne * (ix - ix_sw) * cot; } if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); T I_sw = x[offset]; gix -= I_sw * (iy - iy_ne) * cot; giy += I_sw * (ix_ne - ix) * cot; } if (iy_se >= 
0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { int offset = base_idx + iy_se * h_stride + ix_se * w_stride; atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); T I_se = x[offset]; gix += I_se * (iy - iy_nw) * cot; giy += I_se * (ix - ix_nw) * cot; } } T gix_mult = W / 2.0f; T giy_mult = H / 2.0f; // Reduce across each simdgroup first. // This is much faster than relying purely on atomics. gix = simd_sum(gix); giy = simd_sum(giy); if (thread_index_in_simdgroup == 0) { atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); } """ kernel = mx.fast.metal_kernel( name="grid_sample_grad", input_names=["x", "grid", "cotangent"], output_names=["x_grad", "grid_grad"], source=source, atomic_outputs=True, ) @grid_sample.vjp def grid_sample_vjp(primals, cotangent, _): x, grid = primals B, _, _, C = x.shape _, gN, gM, D = grid.shape assert D == 2, "Last dim of `grid` must be size 2." # pad the output channels to simd group size # so that our `simd_sum`s don't overlap. simdgroup_size = 32 C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size grid_size = B * gN * gM * C_padded outputs = kernel( inputs=[x, grid, cotangent], template=[("T", x.dtype)], output_shapes=[x.shape, grid.shape], output_dtypes=[x.dtype, x.dtype], grid=(grid_size, 1, 1), threadgroup=(256, 1, 1), init_value=0, ) return outputs[0], outputs[1] There's an even larger speed up for the vjp: ``676.4ms -> 16.7ms => 40x speed up`` ================================================ FILE: docs/src/dev/extensions.rst ================================================ Custom Extensions in MLX ======================== You can extend MLX with custom operations on the CPU or GPU. This guide explains how to do that with a simple example. 
Introducing the Example ----------------------- Let's say you would like an operation that takes in two arrays, ``x`` and ``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively, and then adds them together to get the result ``z = alpha * x + beta * y``. You can do that in MLX directly: .. code-block:: python import mlx.core as mx def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array: return alpha * x + beta * y This function performs that operation while leaving the implementation and function transformations to MLX. However, you may want to customize the underlying implementation, perhaps to make it faster. In this tutorial we will go through adding custom extensions. It will cover: * The structure of the MLX library. * Implementing a CPU operation. * Implementing a GPU operation using metal. * Adding the ``vjp`` and ``jvp`` function transformation. * Building a custom extension and binding it to python. Operations and Primitives ------------------------- Operations in MLX build the computation graph. Primitives provide the rules for evaluating and transforming the graph. Let's start by discussing operations in more detail. Operations ^^^^^^^^^^^ Operations are the front-end functions that operate on arrays. They are defined in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them. We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and ``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in C++: .. 
code-block:: C++ /** * Scale and sum two vectors element-wise * z = alpha * x + beta * y * * Use NumPy-style broadcasting between x and y * Inputs are upcasted to floats if needed **/ array axpby( const array& x, // Input array x const array& y, // Input array y const float alpha, // Scaling factor for x const float beta, // Scaling factor for y StreamOrDevice s = {} // Stream on which to schedule the operation ); The simplest way to implement this is with existing operations: .. code-block:: C++ array axpby( const array& x, // Input array x const array& y, // Input array y const float alpha, // Scaling factor for x const float beta, // Scaling factor for y StreamOrDevice s /* = {} */ // Stream on which to schedule the operation ) { // Scale x and y on the provided stream auto ax = multiply(array(alpha), x, s); auto by = multiply(array(beta), y, s); // Add and return return add(ax, by, s); } The operations themselves do not contain the implementations that act on the data, nor do they contain the rules of transformations. Rather, they are an easy to use interface that use :class:`Primitive` building blocks. Primitives ^^^^^^^^^^^ A :class:`Primitive` is part of the computation graph of an :class:`array`. It defines how to create output arrays given input arrays. Further, a :class:`Primitive` has methods to run on the CPU or GPU and for function transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be more concrete: .. code-block:: C++ class Axpby : public Primitive { public: explicit Axpby(Stream stream, float alpha, float beta) : Primitive(stream), alpha_(alpha), beta_(beta){}; /** * A primitive must know how to evaluate itself on the CPU/GPU * for the given inputs and populate the output array. * * To avoid unnecessary allocations, the evaluation function * is responsible for allocating space for the array. 
*/ void eval_cpu( const std::vector& inputs, std::vector& outputs) override; void eval_gpu( const std::vector& inputs, std::vector& outputs) override; /** The Jacobian-vector product. */ std::vector jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) override; /** The vector-Jacobian product. */ std::vector vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; /** * The primitive must know how to vectorize itself across * the given axes. The output is a pair containing the array * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. */ std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; /** The name of primitive. */ const char* name() const override { return "Axpby"; } /** Equivalence check **/ bool is_equivalent(const Primitive& other) const override; private: float alpha_; float beta_; }; The :class:`Axpby` class derives from the base :class:`Primitive` class. The :class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides implementations of how the output array is produced given the inputs through :meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`. Using the Primitive ^^^^^^^^^^^^^^^^^^^ Operations can use this :class:`Primitive` to add a new :class:`array` to the computation graph. An :class:`array` can be constructed by providing its data type, shape, the :class:`Primitive` that computes it, and the :class:`array` inputs that are passed to the primitive. Let's reimplement our operation now in terms of our :class:`Axpby` primitive. .. 
code-block:: C++ array axpby( const array& x, // Input array x const array& y, // Input array y const float alpha, // Scaling factor for x const float beta, // Scaling factor for y StreamOrDevice s /* = {} */ // Stream on which to schedule the operation ) { // Promote dtypes between x and y as needed auto promoted_dtype = promote_types(x.dtype(), y.dtype()); // Upcast to float32 for non-floating point inputs x and y auto out_dtype = issubdtype(promoted_dtype, floating) ? promoted_dtype : promote_types(promoted_dtype, float32); // Cast x and y up to the determined dtype (on the same stream s) auto x_casted = astype(x, out_dtype, s); auto y_casted = astype(y, out_dtype, s); // Broadcast the shapes of x and y (on the same stream s) auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s); auto out_shape = broadcasted_inputs[0].shape(); // Construct the array as the output of the Axpby primitive // with the broadcasted and upcasted arrays as inputs return array( /* const std::vector<int>& shape = */ out_shape, /* Dtype dtype = */ out_dtype, /* std::shared_ptr<Primitive> primitive = */ std::make_shared<Axpby>(to_stream(s), alpha, beta), /* const std::vector<array>& inputs = */ broadcasted_inputs); } This operation now handles the following: #. Upcast inputs and resolve the output data type. #. Broadcast the inputs and resolve the output shape. #. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``. #. Construct the output :class:`array` using the primitive and the inputs. Implementing the Primitive -------------------------- No computation happens when we call the operation alone. The operation only builds the computation graph. When we evaluate the output array, MLX schedules the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the stream/device specified by the user. .. 
warning:: When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called, no memory has been allocated for the output array. It falls on the implementation of these functions to allocate memory as needed. Implementing the CPU Back-end ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Let's start by implementing :meth:`Axpby::eval_cpu`. The method will go over each element of the output array, find the corresponding input elements of ``x`` and ``y`` and perform the operation point-wise. This is captured in the templated function :meth:`axpby_impl`. .. code-block:: C++ template void axpby_impl( const mx::array& x, const mx::array& y, mx::array& out, float alpha_, float beta_, mx::Stream stream) { out.set_data(mx::allocator::malloc(out.nbytes())); // Get the CPU command encoder and register input and output arrays auto& encoder = mx::cpu::get_command_encoder(stream); encoder.set_input_array(x); encoder.set_input_array(y); encoder.set_output_array(out); // Launch the CPU kernel encoder.dispatch([x_ptr = x.data(), y_ptr = y.data(), out_ptr = out.data(), size = out.size(), shape = out.shape(), x_strides = x.strides(), y_strides = y.strides(), alpha_, beta_]() { // Cast alpha and beta to the relevant types T alpha = static_cast(alpha_); T beta = static_cast(beta_); // Do the element-wise operation for each output for (size_t out_idx = 0; out_idx < size; out_idx++) { // Map linear indices to offsets in x and y auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides); auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides); // We allocate the output to be contiguous and regularly strided // (defaults to row major) and hence it doesn't need additional mapping out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; } }); } Our implementation should work for all incoming floating point arrays. Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error if we encounter an unexpected type. .. 
code-block:: C++ void Axpby::eval_cpu( const std::vector& inputs, std::vector& outputs) { auto& x = inputs[0]; auto& y = inputs[1]; auto& out = outputs[0]; // Dispatch to the correct dtype if (out.dtype() == mx::float32) { return axpby_impl(x, y, out, alpha_, beta_, stream()); } else if (out.dtype() == mx::float16) { return axpby_impl(x, y, out, alpha_, beta_, stream()); } else if (out.dtype() == mx::bfloat16) { return axpby_impl(x, y, out, alpha_, beta_, stream()); } else if (out.dtype() == mx::complex64) { return axpby_impl(x, y, out, alpha_, beta_, stream()); } else { throw std::runtime_error( "Axpby is only supported for floating point types."); } } Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If you do not plan on running the operation on the GPU or using transforms on computation graphs that contain :class:`Axpby`, you can stop implementing the primitive here. Implementing the GPU Back-end ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Apple silicon devices address their GPUs using the Metal_ shading language, and GPU kernels in MLX are written using Metal. .. note:: Here are some helpful resources if you are new to Metal: * A walkthrough of the metal compute pipeline: `Metal Example`_ * Documentation for metal shading language: `Metal Specification`_ * Using metal from C++: `Metal-cpp`_ Let's keep the GPU kernel simple. We will launch exactly as many threads as there are elements in the output. Each thread will pick the element it needs from ``x`` and ``y``, do the point-wise operation, and update its assigned element in the output. .. 
code-block:: C++ template [[kernel]] void axpby_general( device const T* x [[buffer(0)]], device const T* y [[buffer(1)]], device T* out [[buffer(2)]], constant const float& alpha [[buffer(3)]], constant const float& beta [[buffer(4)]], constant const int* shape [[buffer(5)]], constant const int64_t* x_strides [[buffer(6)]], constant const int64_t* y_strides [[buffer(7)]], constant const int& ndim [[buffer(8)]], uint index [[thread_position_in_grid]]) { // Convert linear indices to offsets in array auto x_offset = elem_to_loc(index, shape, x_strides, ndim); auto y_offset = elem_to_loc(index, shape, y_strides, ndim); // Do the operation and update the output out[index] = static_cast(alpha) * x[x_offset] + static_cast(beta) * y[y_offset]; } We then need to instantiate this template for all floating point types and give each instantiation a unique host name so we can identify it. .. code-block:: C++ instantiate_kernel("axpby_general_float32", axpby_general, float) instantiate_kernel("axpby_general_float16", axpby_general, float16_t) instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t) instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t) The logic to determine the kernel, set the inputs, resolve the grid dimensions, and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown below. .. 
code-block:: C++ /** Evaluate primitive on GPU */ void Axpby::eval_gpu( const std::vector& inputs, std::vector& outputs) { // Prepare inputs assert(inputs.size() == 2); auto& x = inputs[0]; auto& y = inputs[1]; auto& out = outputs[0]; // Each primitive carries the stream it should execute on // and each stream carries its device identifiers auto& s = stream(); // We get the needed metal device using the stream auto& d = metal::device(s.device); // Allocate output memory out.set_data(allocator::malloc(out.nbytes())); // Resolve name of kernel std::string kname = "axpby_general_" + type_to_name(out); // Load the metal library auto lib = d.get_library("mlx_ext", current_binary_dir()); // Make a kernel from this metal library auto kernel = d.get_kernel(kname, lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); // Kernel parameters are registered with buffer indices corresponding to // those in the kernel declaration at axpby.metal int ndim = out.ndim(); size_t nelem = out.size(); // Encode input arrays to kernel compute_encoder.set_input_array(x, 0); compute_encoder.set_input_array(y, 1); // Encode output arrays to kernel compute_encoder.set_output_array(out, 2); // Encode alpha and beta compute_encoder.set_bytes(alpha_, 3); compute_encoder.set_bytes(beta_, 4); // Encode shape, strides and ndim compute_encoder.set_vector_bytes(x.shape(), 5); compute_encoder.set_vector_bytes(x.strides(), 6); compute_encoder.set_vector_bytes(y.strides(), 7); compute_encoder.set_bytes(ndim, 8); // We launch 1 thread for each output element and make sure that the number of // threads in any given threadgroup is not higher than the max allowed size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup()); // Fix the 3D size of each threadgroup (in terms of threads) MTL::Size group_dims = MTL::Size(tgp_size, 1, 1); // Fix the 3D size of the launch grid (in terms of threads) MTL::Size grid_dims = 
MTL::Size(nelem, 1, 1); // Launch the grid with the given number of threads divided among // the given threadgroups compute_encoder.dispatch_threads(grid_dims, group_dims); } We can now call the :meth:`axpby` operation on both the CPU and the GPU! A few things to note about MLX and Metal before moving on. MLX keeps track of the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is associated. We rely on :meth:`d.get_command_encoder` to give us the active metal compute command encoder instead of building a new one and calling :meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute pipelines) to the active command buffer until some specified limit is hit or the command buffer needs to be flushed for synchronization. Primitive Transforms ^^^^^^^^^^^^^^^^^^^^^ Next, let's add implementations for transformations in a :class:`Primitive`. These transformations can be built on top of other operations, including the one we just defined: .. code-block:: C++ /** The Jacobian-vector product. */ std::vector Axpby::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { // Forward mode diff that pushes along the tangents // The jvp transform on the primitive can be built with ops // that are scheduled on the same stream as the primitive // If argnums = {0}, we only push along x in which case the // jvp is just the tangent scaled by alpha // Similarly, if argnums = {1}, the jvp is just the tangent // scaled by beta if (argnums.size() == 1) { auto scale = argnums[0] == 0 ? alpha_ : beta_; auto scale_arr = array(scale, tangents[0].dtype()); return {multiply(scale_arr, tangents[0], stream())}; } // If argnums = {0, 1}, we take contributions from both // which gives us jvp = tangent_x * alpha + tangent_y * beta else { return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; } } .. code-block:: C++ /** The vector-Jacobian product. 
*/ std::vector Axpby::vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& /* unused */) { // Reverse mode diff std::vector vjps; for (auto arg : argnums) { auto scale = arg == 0 ? alpha_ : beta_; auto scale_arr = array(scale, cotangents[0].dtype()); vjps.push_back(multiply(scale_arr, cotangents[0], stream())); } return vjps; } Note, a transformation does not need to be fully defined to start using the :class:`Primitive`. .. code-block:: C++ /** Vectorize primitive along given axis */ std::pair, std::vector> Axpby::vmap( const std::vector& inputs, const std::vector& axes) { throw std::runtime_error("[Axpby] vmap not implemented."); } Building and Binding -------------------- Let's look at the overall directory structure first. | extensions | ├── axpby | │ ├── axpby.cpp | │ ├── axpby.h | │ └── axpby.metal | ├── mlx_sample_extensions | │ └── __init__.py | ├── bindings.cpp | ├── CMakeLists.txt | └── setup.py * ``extensions/axpby/`` defines the C++ extension library * ``extensions/mlx_sample_extensions`` sets out the structure for the associated Python package * ``extensions/bindings.cpp`` provides Python bindings for our operation * ``extensions/CMakeLists.txt`` holds CMake rules to build the library and Python bindings * ``extensions/setup.py`` holds the ``setuptools`` rules to build and install the Python package Binding to Python ^^^^^^^^^^^^^^^^^^ We use nanobind_ to build a Python API for the C++ library. Since bindings for components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are already provided, adding our :meth:`axpby` is simple. .. 
code-block:: C++ NB_MODULE(_ext, m) { m.doc() = "Sample extension for MLX"; m.def( "axpby", &axpby, "x"_a, "y"_a, "alpha"_a, "beta"_a, nb::kw_only(), "stream"_a = nb::none(), R"( Scale and sum two vectors element-wise ``z = alpha * x + beta * y`` Follows numpy style broadcasting between ``x`` and ``y`` Inputs are upcasted to floats if needed Args: x (array): Input array. y (array): Input array. alpha (float): Scaling factor for ``x``. beta (float): Scaling factor for ``y``. Returns: array: ``alpha * x + beta * y`` )"); } Most of the complexity in the above example comes from additional bells and whistles such as the literal names and doc-strings. .. warning:: :mod:`mlx.core` must be imported before importing :mod:`mlx_sample_extensions` as defined by the nanobind module above to ensure that the casters for :mod:`mlx.core` components like :class:`mlx.core.array` are available. .. _Building with CMake: Building with CMake ^^^^^^^^^^^^^^^^^^^^ Building the C++ extension library only requires that you ``find_package(MLX CONFIG)`` and then link it to your library. .. code-block:: cmake # Add library add_library(mlx_ext) # Add sources target_sources( mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp ) # Add include headers target_include_directories( mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR} ) # Link to mlx target_link_libraries(mlx_ext PUBLIC mlx) We also need to build the attached Metal library. For convenience, we provide a :meth:`mlx_build_metallib` function that builds a ``.metallib`` target given sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and automatically imported with MLX package). Here is what that looks like in practice: .. 
code-block:: cmake # Build metallib if(MLX_BUILD_METAL) mlx_build_metallib( TARGET mlx_ext_metallib TITLE mlx_ext SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS} OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY} ) add_dependencies( mlx_ext mlx_ext_metallib ) endif() Finally, we build the nanobind_ bindings .. code-block:: cmake nanobind_add_module( _ext NB_STATIC STABLE_ABI LTO NOMINSIZE NB_DOMAIN mlx ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp ) target_link_libraries(_ext PRIVATE mlx_ext) if(BUILD_SHARED_LIBS) target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path) endif() Building with ``setuptools`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Once we have set out the CMake build rules as described above, we can use the build utilities defined in :mod:`mlx.extension`: .. code-block:: python from mlx import extension from setuptools import setup if __name__ == "__main__": setup( name="mlx_sample_extensions", version="0.0.0", description="Sample C++ and Metal extensions for MLX primitives.", ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")], cmdclass={"build_ext": extension.CMakeBuild}, packages=["mlx_sample_extensions"], package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]}, extras_require={"dev":[]}, zip_safe=False, python_requires=">=3.8", ) .. note:: We treat ``extensions/mlx_sample_extensions`` as the package directory even though it only contains a ``__init__.py`` to ensure the following: * :mod:`mlx.core` must be imported before importing :mod:`_ext` * The C++ extension library and the metal library are co-located with the python bindings and copied together if the package is installed To build the package, first install the build dependencies with ``pip install -r requirements.txt``. 
You can then build inplace for development using ``python setup.py build_ext -j8 --inplace`` (in ``extensions/``) This results in the directory structure: | extensions | ├── mlx_sample_extensions | │ ├── __init__.py | │ ├── libmlx_ext.dylib # C++ extension library | │ ├── mlx_ext.metallib # Metal library | │ └── _ext.cpython-3x-darwin.so # Python Binding | ... When you try to install using the command ``python -m pip install .`` (in ``extensions/``), the package will be installed with the same structure as ``extensions/mlx_sample_extensions`` and the C++ and Metal library will be copied along with the Python binding since they are specified as ``package_data``. Usage ----- After installing the extension as described above, you should be able to simply import the Python package and play with it as you would any other MLX operation. Let's look at a simple script and its results: .. code-block:: python import mlx.core as mx from mlx_sample_extensions import axpby a = mx.ones((3, 4)) b = mx.ones((3, 4)) c = axpby(a, b, 4.0, 2.0, stream=mx.cpu) print(f"c shape: {c.shape}") print(f"c dtype: {c.dtype}") print(f"c is correct: {mx.all(c == 6.0).item()}") Output: .. code-block:: c shape: [3, 4] c dtype: float32 c is correct: True Results ^^^^^^^ Let's run a quick benchmark and see how our new ``axpby`` operation compares with the naive :meth:`simple_axpby` we first defined. .. 
code-block:: python import mlx.core as mx from mlx_sample_extensions import axpby import time def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array: return alpha * x + beta * y M = 4096 N = 4096 x = mx.random.normal((M, N)) y = mx.random.normal((M, N)) alpha = 4.0 beta = 2.0 mx.eval(x, y) def bench(f): # Warm up for i in range(5): z = f(x, y, alpha, beta) mx.eval(z) # Timed run s = time.perf_counter() for i in range(100): z = f(x, y, alpha, beta) mx.eval(z) e = time.perf_counter() return 1000 * (e - s) / 100 simple_time = bench(simple_axpby) custom_time = bench(axpby) print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms") The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see modest improvements right away! This operation is now good to be used to build other operations, in :class:`mlx.nn.Module` calls, and also as a part of graph transformations like :meth:`grad`. Scripts ------- .. admonition:: Download the code The full example code is available in `mlx `_. .. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc .. _Metal: https://developer.apple.com/documentation/metal?language=objc .. _Metal-cpp: https://developer.apple.com/metal/cpp/ .. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf .. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc .. _nanobind: https://nanobind.readthedocs.io/en/latest/ ================================================ FILE: docs/src/dev/metal_debugger.rst ================================================ Metal Debugger ============== .. currentmodule:: mlx.core Profiling is a key step for performance optimization. You can build MLX with the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and optimization workflow. 
The ``MLX_METAL_DEBUG`` debug option: * Records source during Metal compilation, for later inspection while debugging. * Labels Metal objects such as command queues, improving capture readability. To build with debugging enabled in Python, prepend ``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call. The :func:`metal.start_capture` function initiates a capture of all MLX GPU work. .. note:: To capture a GPU trace you must run the application with ``MTL_CAPTURE_ENABLED=1``. .. code-block:: python import mlx.core as mx a = mx.random.uniform(shape=(512, 512)) b = mx.random.uniform(shape=(512, 512)) mx.eval(a, b) trace_file = "mlx_trace.gputrace" # Make sure to run with MTL_CAPTURE_ENABLED=1 and # that the path trace_file does not already exist. mx.metal.start_capture(trace_file) for _ in range(10): mx.eval(mx.add(a, b)) mx.metal.stop_capture() You can open and replay the GPU trace in Xcode. The ``Dependencies`` view has a great overview of all operations. Check out the `Metal debugger documentation`_ for more information. .. image:: ../_static/metal_debugger/capture.png :class: dark-light Xcode Workflow -------------- You can skip saving to a path by running within Xcode. First, generate an Xcode project using CMake. .. code-block:: mkdir build && cd build cmake .. -DMLX_METAL_DEBUG=ON -G Xcode open mlx.xcodeproj Select the ``metal_capture`` example scheme and run. .. image:: ../_static/metal_debugger/schema.png :class: dark-light .. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger ================================================ FILE: docs/src/dev/metal_logging.rst ================================================ Metal Logging ============= In debug builds, MLX compiles Metal kernels with ``os_log`` enabled so shader warnings and debug messages are visible during development. .. note:: Metal logging is only available with Metal 3.2 or higher (macOS 15 and up, iOS 18 and up). 
To enable logging from kernels, first make sure to build in debug mode: .. code-block:: bash DEBUG=1 python -m pip install -e . Then, in the kernel source code, include MLX's logging shim and use ``mlx::os_log``: .. code-block:: #include "mlx/backend/metal/kernels/logging.h" constant mlx::os_log logger("mlx", "my_kernel"); kernel void my_kernel(/* ... */) { // ... logger.log_debug("unexpected state: idx=%u", idx); } When you run the program, set the Metal log level to your desired level and forward logs to ``stderr``: .. code-block:: bash MTL_LOG_LEVEL=MTLLogLevelDebug MTL_LOG_TO_STDERR=1 python script.py See the `Metal logging guide`_ for more details. .. _`Metal logging guide`: https://developer.apple.com/documentation/metal/logging-shader-debug-messages ================================================ FILE: docs/src/dev/mlx_in_cpp.rst ================================================ .. _mlx_in_cpp: Using MLX in C++ ================ You can use MLX in a C++ project with CMake. .. note:: This guide is based on the following `example using MLX in C++ `_ First install MLX: .. code-block:: bash pip install -U mlx You can also install the MLX Python package from source or just the C++ library. For more information see the :ref:`documentation on installing MLX `. Next make an example program in ``example.cpp``: .. code-block:: C++ #include <iostream> #include "mlx/mlx.h" namespace mx = mlx::core; int main() { auto x = mx::array({1, 2, 3}); auto y = mx::array({1, 2, 3}); std::cout << x + y << std::endl; return 0; } The next step is to set up a CMake file in ``CMakeLists.txt``: .. code-block:: cmake cmake_minimum_required(VERSION 3.27) project(example LANGUAGES CXX) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) Depending on how you installed MLX, you may need to tell CMake where to find it. If you installed MLX with Python, then add the following to the CMake file: .. 
code-block:: cmake find_package( Python 3.9 COMPONENTS Interpreter Development.Module REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE MLX_ROOT) If you installed the MLX C++ package to a system path, then CMake should be able to find it. If you installed it to a non-standard location or CMake can't find MLX then set ``MLX_ROOT`` to the location where MLX is installed: .. code-block:: cmake set(MLX_ROOT "/path/to/mlx/") Next, instruct CMake to find MLX: .. code-block:: cmake find_package(MLX CONFIG REQUIRED) Finally, add the ``example.cpp`` program as an executable and link MLX. .. code-block:: cmake add_executable(example example.cpp) target_link_libraries(example PRIVATE mlx) You can build the example with: .. code-block:: bash cmake -B build -DCMAKE_BUILD_TYPE=Release cmake --build build And run it with: .. code-block:: bash ./build/example Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables: .. list-table:: Package Variables :widths: 20 20 :header-rows: 1 * - Variable - Description * - MLX_FOUND - ``True`` if MLX is found * - MLX_INCLUDE_DIRS - Include directory * - MLX_LIBRARIES - Libraries to link against * - MLX_CXX_FLAGS - Additional compiler flags * - MLX_BUILD_ACCELERATE - ``True`` if MLX was built with Accelerate * - MLX_BUILD_METAL - ``True`` if MLX was built with Metal ================================================ FILE: docs/src/examples/data_parallelism.rst ================================================ .. _data_parallelism: Data Parallelism ================ MLX enables efficient data parallel distributed training through its distributed communication primitives. .. _training_example: Training Example ---------------- In this section we will adapt an MLX training loop to support data parallel distributed training. Namely, we will average the gradients across a set of hosts before applying them to the model. 
Our training loop looks like the following code snippet if we omit the model, dataset, and optimizer initialization. .. code:: python model = ... optimizer = ... dataset = ... def step(model, x, y): loss, grads = loss_grad_fn(model, x, y) optimizer.update(model, grads) return loss for x, y in dataset: loss = step(model, x, y) mx.eval(loss, model.parameters()) All we have to do to average the gradients across machines is perform an :func:`all_sum` and divide by the size of the :class:`Group`. Namely, we have to :func:`mlx.utils.tree_map` the gradients with the following function. .. code:: python def all_avg(x): return mx.distributed.all_sum(x) / mx.distributed.init().size() Putting everything together, our training loop step looks as follows with everything else remaining the same. .. code:: python from mlx.utils import tree_map def all_reduce_grads(grads): N = mx.distributed.init().size() if N == 1: return grads return tree_map( lambda x: mx.distributed.all_sum(x) / N, grads ) def step(model, x, y): loss, grads = loss_grad_fn(model, x, y) grads = all_reduce_grads(grads) # <--- This line was added optimizer.update(model, grads) return loss Using ``nn.average_gradients`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Although the code example above works correctly, it performs one communication per gradient. It is significantly more efficient to aggregate several gradients together and perform fewer communication steps. This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks almost identical to the example above: .. code:: python model = ... optimizer = ... dataset = ... 
def step(model, x, y): loss, grads = loss_grad_fn(model, x, y) grads = nn.average_gradients(grads) # <---- This line was added optimizer.update(model, grads) return loss for x, y in dataset: loss = step(model, x, y) mx.eval(loss, model.parameters()) ================================================ FILE: docs/src/examples/linear_regression.rst ================================================ .. _linear_regression: Linear Regression ----------------- Let's implement a basic linear regression model as a starting point to learn MLX. First import the core package and set up some problem metadata: .. code-block:: python import mlx.core as mx num_features = 100 num_examples = 1_000 num_iters = 10_000 # iterations of SGD lr = 0.01 # learning rate for SGD We'll generate a synthetic dataset by: 1. Sampling the design matrix ``X``. 2. Sampling a ground truth parameter vector ``w_star``. 3. Computing the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``. .. code-block:: python # True parameters w_star = mx.random.normal((num_features,)) # Input examples (design matrix) X = mx.random.normal((num_examples, num_features)) # Noisy labels eps = 1e-2 * mx.random.normal((num_examples,)) y = X @ w_star + eps We will use SGD to find the optimal weights. To start, define the squared loss and get the gradient function of the loss with respect to the parameters. .. code-block:: python def loss_fn(w): return 0.5 * mx.mean(mx.square(X @ w - y)) grad_fn = mx.grad(loss_fn) Start the optimization by initializing the parameters ``w`` randomly. Then repeatedly update the parameters for ``num_iters`` iterations. .. code-block:: python w = 1e-2 * mx.random.normal((num_features,)) for _ in range(num_iters): grad = grad_fn(w) w = w - lr * grad mx.eval(w) Finally, compute the loss of the learned parameters and verify that they are close to the ground truth parameters. .. 
code-block:: python loss = loss_fn(w) error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5 print( f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, " ) # Should print something close to: Loss 0.00005, |w-w*| = 0.00364 Complete `linear regression `_ and `logistic regression `_ examples are available in the MLX GitHub repo. ================================================ FILE: docs/src/examples/llama-inference.rst ================================================ LLM inference ============== MLX enables efficient inference of large-ish transformers on Apple silicon without compromising on ease of use. In this example we will create an inference script for the Llama family of transformer models in which the model is defined in less than 200 lines of python. Implementing the model ---------------------- We will use the neural network building blocks defined in the :mod:`mlx.nn` module to concisely define the model architecture. Attention layer ^^^^^^^^^^^^^^^^ We will start with the Llama attention layer which notably uses the RoPE positional encoding. [1]_ In addition, our attention layer will optionally use a key/value cache that will be concatenated with the provided keys and values to support efficient inference. Our implementation uses :class:`mlx.nn.Linear` for all the projections and :class:`mlx.nn.RoPE` for the positional encoding. .. 
code-block:: python import math import mlx.core as mx import mlx.nn as nn class LlamaAttention(nn.Module): def __init__(self, dims: int, num_heads: int): super().__init__() self.num_heads = num_heads self.rope = nn.RoPE(dims // num_heads, traditional=True) self.query_proj = nn.Linear(dims, dims, bias=False) self.key_proj = nn.Linear(dims, dims, bias=False) self.value_proj = nn.Linear(dims, dims, bias=False) self.out_proj = nn.Linear(dims, dims, bias=False) def __call__(self, queries, keys, values, mask=None, cache=None): queries = self.query_proj(queries) keys = self.key_proj(keys) values = self.value_proj(values) # Extract some shapes num_heads = self.num_heads B, L, D = queries.shape # Prepare the queries, keys and values for the attention computation queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) # Add RoPE to the queries and keys and combine them with the cache if cache is not None: key_cache, value_cache = cache queries = self.rope(queries, offset=key_cache.shape[2]) keys = self.rope(keys, offset=key_cache.shape[2]) keys = mx.concatenate([key_cache, keys], axis=2) values = mx.concatenate([value_cache, values], axis=2) else: queries = self.rope(queries) keys = self.rope(keys) # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores = scores + mask scores = mx.softmax(scores, axis=-1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) # Note that we return the keys and values to possibly be used as a cache return self.out_proj(values_hat), (keys, values) Encoder layer ^^^^^^^^^^^^^ The other component of the Llama model is the encoder layer which uses RMS normalization [2]_ and SwiGLU. [3]_ For RMS normalization we will use :class:`mlx.nn.RMSNorm` that is already provided in :mod:`mlx.nn`. .. 
code-block:: python class LlamaEncoderLayer(nn.Module): def __init__(self, dims: int, mlp_dims: int, num_heads: int): super().__init__() self.attention = LlamaAttention(dims, num_heads) self.norm1 = nn.RMSNorm(dims) self.norm2 = nn.RMSNorm(dims) self.linear1 = nn.Linear(dims, mlp_dims, bias=False) self.linear2 = nn.Linear(dims, mlp_dims, bias=False) self.linear3 = nn.Linear(mlp_dims, dims, bias=False) def __call__(self, x, mask=None, cache=None): y = self.norm1(x) y, cache = self.attention(y, y, y, mask, cache) x = x + y y = self.norm2(x) a = self.linear1(y) b = self.linear2(y) y = a * mx.sigmoid(a) * b y = self.linear3(y) x = x + y return x, cache Full model ^^^^^^^^^^ To implement any Llama model we simply have to combine ``LlamaEncoderLayer`` instances with an :class:`mlx.nn.Embedding` to embed the input tokens. .. code-block:: python class Llama(nn.Module): def __init__( self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int ): super().__init__() self.embedding = nn.Embedding(vocab_size, dims) self.layers = [ LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers) ] self.norm = nn.RMSNorm(dims) self.out_proj = nn.Linear(dims, vocab_size, bias=False) def __call__(self, x): mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) mask = mask.astype(self.embedding.weight.dtype) x = self.embedding(x) for l in self.layers: x, _ = l(x, mask) x = self.norm(x) return self.out_proj(x) Note that in the implementation above we use a simple list to hold the encoder layers but using ``model.parameters()`` will still consider these layers. Generation ^^^^^^^^^^^ Our ``Llama`` module can be used for training but not inference as the ``__call__`` method above processes one input, completely ignores the cache and performs no sampling whatsoever. In the rest of this subsection, we will implement the inference function as a python generator that processes the prompt and then autoregressively yields tokens one at a time. .. 
code-block:: python class Llama(nn.Module): ... def generate(self, x, temp=1.0): cache = [] # Make an additive causal mask. We will need that to process the prompt. mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) mask = mask.astype(self.embedding.weight.dtype) # First we process the prompt x the same way as in __call__ but # save the caches in cache x = self.embedding(x) for l in self.layers: x, c = l(x, mask=mask) cache.append(c) # <--- we store the per layer cache in a # simple python list x = self.norm(x) y = self.out_proj(x[:, -1]) # <--- we only care about the last logits # that generate the next token y = mx.random.categorical(y * (1/temp)) # y now has size [1] # Since MLX is lazily evaluated nothing is computed yet. # Calling y.item() would force the computation to happen at # this point but we can also choose not to do that and let the # user choose when to start the computation. yield y # Now we parsed the prompt and generated the first token we # need to feed it back into the model and loop to generate the # rest. while True: # Unsqueezing the last dimension to add a sequence length # dimension of 1 x = y[:, None] x = self.embedding(x) for i in range(len(cache)): # We are overwriting the arrays in the cache list. When # the computation will happen, MLX will be discarding the # old cache the moment it is not needed anymore. x, cache[i] = self.layers[i](x, mask=None, cache=cache[i]) x = self.norm(x) y = self.out_proj(x[:, -1]) y = mx.random.categorical(y * (1/temp)) yield y Putting it all together ^^^^^^^^^^^^^^^^^^^^^^^ We now have everything we need to create a Llama model and sample tokens from it. In the following code, we randomly initialize a small Llama model, process 6 tokens of prompt and generate 10 tokens. .. code-block:: python model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8) # Since MLX is lazily evaluated nothing has actually been materialized yet. 
# We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the # code above would still run. Let's actually materialize the model. mx.eval(model.parameters()) prompt = mx.array([[1, 10, 8, 32, 44, 7]]) # <-- Note the double brackets because we # have a batch dimension even # though it is 1 in this case generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))] # Since we haven't evaluated anything, nothing is computed yet. The list # `generated` contains the arrays that hold the computation graph for the # full processing of the prompt and the generation of 10 tokens. # # We can evaluate them one at a time, or all together. Concatenate them or # print them. They would all result in very similar runtimes and give exactly # the same results. mx.eval(generated) Converting the weights ---------------------- This section assumes that you have access to the original Llama weights and the SentencePiece model that comes with them. We will write a small script to convert the PyTorch weights to MLX compatible ones and write them in a NPZ file that can be loaded directly by MLX. .. 
code-block:: python import argparse from itertools import starmap import numpy as np import torch def map_torch_to_mlx(key, value): if "tok_embedding" in key: key = "embedding.weight" elif "norm" in key: key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2") elif "wq" in key or "wk" in key or "wv" in key or "wo" in key: key = key.replace("wq", "query_proj") key = key.replace("wk", "key_proj") key = key.replace("wv", "value_proj") key = key.replace("wo", "out_proj") elif "w1" in key or "w2" in key or "w3" in key: # The FFN is a separate submodule in PyTorch key = key.replace("feed_forward.w1", "linear1") key = key.replace("feed_forward.w3", "linear2") key = key.replace("feed_forward.w2", "linear3") elif "output" in key: key = key.replace("output", "out_proj") elif "rope" in key: return None, None return key, value.numpy() if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Llama weights to MLX") parser.add_argument("torch_weights") parser.add_argument("output_file") args = parser.parse_args() state = torch.load(args.torch_weights) np.savez( args.output_file, **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None} ) Weight loading and benchmarking ------------------------------- After converting the weights to be compatible to our implementation, all that is left is to load them from disk and we can finally use the LLM to generate text. We can load numpy format files using the :func:`mlx.core.load` operation. To create a parameter dictionary from the key/value representation of NPZ files we will use the :func:`mlx.utils.tree_unflatten` helper method as follows: .. code-block:: python from mlx.utils import tree_unflatten model.update(tree_unflatten(list(mx.load(weight_file).items()))) :meth:`mlx.utils.tree_unflatten` will take keys from the NPZ file that look like ``layers.2.attention.query_proj.weight`` and will transform them to .. 
code-block:: python {"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]} which can then be used to update the model. Note that the method above incurs several unnecessary copies from disk to numpy and then from numpy to MLX. It will be replaced in the future with direct loading to MLX. You can download the full example code in `mlx-examples`_. Assuming the existence of ``weights.pth`` and ``tokenizer.model`` in the current working directory, we can play around with our inference script as follows (the timings are representative of an M1 Ultra and the 7B parameter Llama model): .. code-block:: bash $ python convert.py weights.pth llama-7B.mlx.npz $ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely' [INFO] Loading model from disk: 5.247 s Press enter to start generation ------ , having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, ------ [INFO] Prompt processing: 0.437 s [INFO] Full generation: 4.330 s We observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds of those are spent processing the prompt. This amounts to a little over **39 ms per token**. By running with a much bigger prompt we can see that the per-token generation time, as well as the prompt processing time, remains almost constant. .. code-block:: bash $ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael.
Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not' [INFO] Loading model from disk: 5.247 s Press enter to start generation ------ take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not ------ [INFO] Prompt processing: 0.579 s [INFO] Full generation: 4.690 s $ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. 
Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not' [INFO] Loading model from disk: 5.628 s Press enter to start generation ------ take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. 
“I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. “Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “ ------ [INFO] Prompt processing: 0.633 s [INFO] Full generation: 21.475 s Scripts ------- .. admonition:: Download the code The full example code is available in `mlx-examples`_. .. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama .. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021. Roformer: Enhanced transformer with rotary position embedding. arXiv preprint arXiv:2104.09864. .. [2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization. Advances in Neural Information Processing Systems, 32. .. [3] Shazeer, N., 2020. Glu variants improve transformer. arXiv preprint arXiv:2002.05202. ================================================ FILE: docs/src/examples/mlp.rst ================================================ .. 
_mlp: Multi-Layer Perceptron ---------------------- In this example we'll learn to use ``mlx.nn`` by implementing a simple multi-layer perceptron to classify MNIST. As a first step, import the MLX packages we need: .. code-block:: python import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim import numpy as np The model is defined as the ``MLP`` class which inherits from :class:`mlx.nn.Module`. We follow the standard idiom to make a new module: 1. Define an ``__init__`` where the parameters and/or submodules are set up. See the :ref:`Module class docs` for more information on how :class:`mlx.nn.Module` registers parameters. 2. Define a ``__call__`` where the computation is implemented. .. code-block:: python class MLP(nn.Module): def __init__( self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int ): super().__init__() layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] self.layers = [ nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) ] def __call__(self, x): for l in self.layers[:-1]: x = mx.maximum(l(x), 0.0) return self.layers[-1](x) We define the loss function which takes the mean of the per-example cross entropy loss. The ``mlx.nn.losses`` sub-package has implementations of some commonly used loss functions. .. code-block:: python def loss_fn(model, X, y): return mx.mean(nn.losses.cross_entropy(model(X), y)) We also need a function to compute the accuracy of the model on the test set: .. code-block:: python def eval_fn(model, X, y): return mx.mean(mx.argmax(model(X), axis=1) == y) Next, set up the problem parameters and load the data. To load the data, you need our `mnist data loader `_, which we will import as ``mnist``. ..
code-block:: python num_layers = 2 hidden_dim = 32 num_classes = 10 batch_size = 256 num_epochs = 10 learning_rate = 1e-1 # Load the data import mnist train_images, train_labels, test_images, test_labels = map( mx.array, mnist.mnist() ) Since we're using SGD, we need an iterator which shuffles and constructs minibatches of examples in the training set: .. code-block:: python def batch_iterate(batch_size, X, y): perm = mx.array(np.random.permutation(y.size)) for s in range(0, y.size, batch_size): ids = perm[s : s + batch_size] yield X[ids], y[ids] Finally, we put it all together by instantiating the model, the :class:`mlx.optimizers.SGD` optimizer, and running the training loop: .. code-block:: python # Load the model model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes) mx.eval(model.parameters()) # Get a function which gives the loss and gradient of the # loss with respect to the model's trainable parameters loss_and_grad_fn = nn.value_and_grad(model, loss_fn) # Instantiate the optimizer optimizer = optim.SGD(learning_rate=learning_rate) for e in range(num_epochs): for X, y in batch_iterate(batch_size, train_images, train_labels): loss, grads = loss_and_grad_fn(model, X, y) # Update the optimizer state and model parameters # in a single call optimizer.update(model, grads) # Force a graph evaluation mx.eval(model.parameters(), optimizer.state) accuracy = eval_fn(model, test_images, test_labels) print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}") .. note:: The :func:`mlx.nn.value_and_grad` function is a convenience function to get the gradient of a loss with respect to the trainable parameters of a model. This should not be confused with :func:`mlx.core.value_and_grad`. The model should train to a decent accuracy (about 95%) after just a few passes over the training set. The `full example `_ is available in the MLX GitHub repo. 
================================================ FILE: docs/src/examples/tensor_parallelism.rst ================================================ .. _tensor_parallelism: Tensor Parallelism ================== In this example, we will explore how tensor parallelism (TP) works in MLX. We will start with an overview of the distributed layers in ``mlx.nn`` and then show how to apply tensor parallelism to Llama-style transformer models. Sharded Layers -------------- :class:`AllToShardedLinear ` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ This layer replicates a common input and shards the weight matrix along the output dimension across all devices in the :class:`mlx.core.distributed.Group`. The layer produces a sharded output. For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with ``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``, and a device group with 2 devices. The layer shards the weight matrix along the output dimension across the two devices, where each device receives the full input and computes a partial output. .. raw:: html
column-wise tensor parallelism
This layer does not automatically gather all outputs from each device. This is an intended and :ref:`useful design choice `. :class:`QuantizedAllToShardedLinear ` is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be included in any gradient computation. :class:`ShardedToAllLinear ` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ This layer expects inputs that are sharded along the feature dimension and shards the weight matrix along the input dimension across all devices in the :class:`mlx.core.distributed.Group`. The layer automatically aggregates the results using :func:`mlx.core.distributed.all_sum`, so all devices in the group will have the same result. For example, consider an :class:`mlx.nn.ShardedToAllLinear` layer with ``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``, and a device group with 2 devices. The layer shards the weight matrix along the input dimension across the two devices. Each device computes a partial ``(4, 2)`` output, which is then summed with the outputs of all other devices to produce the final layer output. .. raw:: html
row-wise tensor parallelism
This layer does not automatically shard the inputs along the feature dimension for you. It is necessary to create a "partial" input structure to feed into the layer. This is an intended and :ref:`useful design choice `. :class:`QuantizedShardedToAllLinear ` is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be included in any gradient computation. Shard Utility Functions ----------------------- :func:`shard_linear ` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Converts a regular linear layer into a tensor parallel layer that distributes computation across multiple devices. Takes an existing :class:`mlx.nn.Linear` or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer (either :class:`mlx.nn.AllToShardedLinear` or :class:`mlx.nn.ShardedToAllLinear`, or their quantized counterparts when given a quantized layer, depending on the sharding type). The original layer is not modified. :func:`shard_inplace ` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Splits the parameters of an existing layer across multiple devices by modifying the layer in-place. Unlike :func:`shard_linear `, this function does not create a new layer or add distributed communication. The layer itself must handle distributed communication if needed. .. _useful_design_choices: Useful Design Choices --------------------- The design choices above regarding when operations are done automatically are intentional and make model training and inference easier. All-to-sharded and sharded-to-all layers naturally go together because the output of the former layer is exactly the input needed for the latter. This removes the need for an intermediate gather step between the layers, reducing communication overhead. This is why :class:`mlx.nn.AllToShardedLinear` does not aggregate results automatically and why :class:`mlx.nn.ShardedToAllLinear` does not shard inputs automatically.
This way, they can be placed one after the other and work together seamlessly. We can demonstrate this through a simple model using our two types of distributed layers. .. code-block:: python x = ... # some (4, 2) model input: batch size 4, feature size 2 l1 = nn.AllToShardedLinear(2, 2, bias=False) # initialize the layer l1_out = l1(x) # (4, 1) output l2 = nn.ShardedToAllLinear(2, 2, bias=False) l2_out = l2(l1_out) # (4, 2) output .. raw:: html
two layer tensor parallelism

A visualization of the simple MLX model using all-to-sharded then sharded-to-all tensor parallelism across 2 devices.

LLM Inference with Tensor Parallelism ------------------------------------- We can apply these TP techniques to LLMs in order to enable inference for much larger models by sharding parameters from huge layers across multiple devices. To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama Inference ` example. In this example, we will use the same inference script as the Llama Inference example, which can be found in `mlx-examples`_. Our first edit is to initialize the distributed communication group and get the current process rank: .. code-block:: python world = mx.distributed.init() rank = world.rank() Next, let's look at the current architecture of the transformer block and see how we can apply tensor parallelism: .. raw:: html
llama transformer example
This architecture has two natural places where tensor parallelism can be applied: the attention block and the FFN block. Both follow the same pattern: multiple parallel linear layers operating on the same input, followed by a single output linear layer. In the attention block, the Q, K, and V projections are sharded along the output dimension (all-to-sharded), and the output projection is sharded along the input dimension (sharded-to-all). Similarly, in the FFN block, the gate and up projections become all-to-sharded layers, and the down projection becomes a sharded-to-all layer. The intermediate operations between the linear layers (RoPE, softmax, scaled dot-product attention in the attention block, and element-wise multiplication in the FFN block) do not impede the use of our TP paradigm. These operations are either: - **Element-wise operations** (RoPE, element-wise multiplication): These operate independently on each element (or, for RoPE, within each head), preserving the sharding pattern without requiring cross-device communication. - **Operations on non-sharded dimensions** (softmax, scaled dot-product attention): These operate along dimensions that are not sharded (such as the sequence length or the per-head feature dimension), so they can be computed independently on each device. The attention computations ``Q @ K^T`` and ``scores @ V`` work correctly with sharded Q, K, and V tensors because the sharding splits the attention heads across devices: each device holds a subset of complete heads, the matrix multiplications contract over the per-head feature dimension (which is not sharded), and the results remain properly sharded for the subsequent sharded-to-all layer. To implement sharding in our Llama inference, we use :func:`shard_linear ` to get sharded linear layers with distributed communication. This is easier than using :func:`shard_inplace ` and implementing the steps manually in the :code:`__call__` function. The following code shows how to shard the Attention block. The Q, K, and V projection layers are converted to all-to-sharded layers, while the output projection is converted to a sharded-to-all layer.
The number of heads is also adjusted to account for the sharding (the integer division below assumes the number of heads is divisible by the group size): .. code-block:: python # ... in Attention class def shard(self, group: mx.distributed.Group): self.n_heads = self.n_heads // group.size() self.n_kv_heads = self.n_kv_heads // group.size() self.wq = nn.layers.distributed.shard_linear(self.wq, "all-to-sharded", group=group) self.wk = nn.layers.distributed.shard_linear(self.wk, "all-to-sharded", group=group) self.wv = nn.layers.distributed.shard_linear(self.wv, "all-to-sharded", group=group) self.wo = nn.layers.distributed.shard_linear(self.wo, "sharded-to-all", group=group) Similarly, the FeedForward block is sharded by converting the gate (w1) and up (w3) projections to all-to-sharded layers, and the down projection (w2) to a sharded-to-all layer: .. code-block:: python # ... in FeedForward class def shard(self, group: mx.distributed.Group): self.w1 = nn.layers.distributed.shard_linear(self.w1, "all-to-sharded", group=group) self.w2 = nn.layers.distributed.shard_linear(self.w2, "sharded-to-all", group=group) self.w3 = nn.layers.distributed.shard_linear(self.w3, "all-to-sharded", group=group) Finally, in our :code:`load_model` function, we need to apply our sharding functions to all transformer layers when using multiple devices: .. code-block:: python # ... in load_model function if world.size() > 1: # convert Linear layers in Transformer/FFN to appropriate Sharded Layers for layer in model.layers: layer.attention.shard(group=world) layer.feed_forward.shard(group=world) This allows us to use the llama inference file as normal when running :code:`python llama.py`, but now we can also run it across two (or more) devices via :code:`mlx.launch -n 2 llama.py`. ..
_mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama ================================================ FILE: docs/src/index.rst ================================================ MLX === MLX is a NumPy-like array framework designed for efficient and flexible machine learning on Apple silicon, brought to you by Apple machine learning research. The Python API closely follows NumPy with a few exceptions. MLX also has a fully featured C++ API which closely follows the Python API. The main differences between MLX and NumPy are: - **Composable function transformations**: MLX has composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization. - **Lazy computation**: Computations in MLX are lazy. Arrays are only materialized when needed. - **Multi-device**: Operations can run on any of the supported devices (CPU, GPU, ...) The design of MLX is inspired by frameworks like `PyTorch `_, `JAX `_, and `ArrayFire `_. A notable difference between these frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without performing data copies. Currently supported device types are the CPU and GPU. .. toctree:: :caption: Install :maxdepth: 1 install .. toctree:: :caption: Usage :maxdepth: 1 usage/quick_start usage/lazy_evaluation usage/unified_memory usage/indexing usage/saving_and_loading usage/function_transforms usage/compile usage/numpy usage/distributed usage/using_streams usage/export .. toctree:: :caption: Examples :maxdepth: 1 examples/linear_regression examples/mlp examples/llama-inference examples/data_parallelism examples/tensor_parallelism ..
toctree:: :caption: Python API Reference :maxdepth: 1 python/array python/data_types python/devices_and_streams python/export python/ops python/random python/transforms python/fast python/fft python/linalg python/metal python/cuda python/memory_management python/nn python/optimizers python/distributed python/tree_utils .. toctree:: :caption: C++ API Reference :maxdepth: 1 cpp/ops .. toctree:: :caption: Further Reading :maxdepth: 1 dev/extensions dev/metal_debugger dev/metal_logging dev/custom_metal_kernels dev/mlx_in_cpp ================================================ FILE: docs/src/install.rst ================================================ .. _build_and_install: Build and Install ================= Python Installation ------------------- MLX is available on PyPI. All you have to do to use MLX with your own Apple silicon computer is .. code-block:: shell pip install mlx To install from PyPI, your system must meet the following requirements: - Using `Apple silicon `_ - Using a native Python >= 3.10 - macOS >= 14.0 .. note:: MLX is only available on devices running macOS 14.0 or higher. CUDA ^^^^ MLX has a CUDA backend which you can install with: .. code-block:: shell pip install mlx[cuda12] To install the CUDA package from PyPI, your system must meet the following requirements: - Nvidia architecture >= SM 7.5 - Nvidia driver >= 550.54.14 - CUDA toolkit >= 12.0 - Linux distribution with glibc >= 2.35 - Python >= 3.10 For CUDA 13, use ``pip install mlx[cuda13]``. The CUDA 13 package requires an Nvidia driver >= 580 or an appropriate CUDA compatibility package. CPU-only (Linux) ^^^^^^^^^^^^^^^^ For a CPU-only version of MLX that runs on Linux use: ..
code-block:: shell pip install mlx[cpu] To install the CPU-only package from PyPI, your system must meet the following requirements: - Linux distribution with glibc >= 2.35 - Python >= 3.10 Troubleshooting ^^^^^^^^^^^^^^^ *My OS and Python versions are in the required range but pip still does not find a matching distribution.* Probably you are using a non-native Python. The output of .. code-block:: shell python -c "import platform; print(platform.processor())" should be ``arm``. If it is ``i386`` (and you have an M series machine) then you are using a non-native Python. Switch your Python to a native Python. A good way to do this is with `Conda `_. Build from source ----------------- Build Requirements ^^^^^^^^^^^^^^^^^^ - ``libblas-dev``, ``liblapack-dev``, and ``liblapacke-dev`` (Linux) - A C++ compiler with C++20 support (e.g. Clang >= 15.0) - `cmake `_ -- version 3.25 or later, and ``make`` - Xcode >= 15.0 and macOS SDK >= 14.0 .. note:: Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section ` below. Python API ^^^^^^^^^^ .. _python install: To build and install the MLX Python library from source, first, clone MLX from `its GitHub repo `_: .. code-block:: shell git clone git@github.com:ml-explore/mlx.git mlx && cd mlx Then simply build and install MLX using pip: .. code-block:: shell pip install . For developing, install the package with development dependencies, and use an editable install: .. code-block:: shell pip install -e ".[dev]" Once the development dependencies are installed, you can build faster with: .. code-block:: shell python setup.py build_ext --inplace Run the tests with: .. code-block:: shell python -m unittest discover python/tests C++ API ^^^^^^^ .. _cpp install: Currently, MLX must be built and installed from source. As with the Python library, to build and install the MLX C++ library, start by cloning MLX from `its GitHub repo `_: ..
code-block:: shell git clone git@github.com:ml-explore/mlx.git mlx && cd mlx Create a build directory and run CMake and make: .. code-block:: shell mkdir -p build && cd build cmake .. && make -j Run tests with: .. code-block:: shell make test Install with: .. code-block:: shell make install Note that the built ``mlx.metallib`` file should either be in the same directory as the executable statically linked to ``libmlx.a``, or the preprocessor constant ``METAL_PATH`` should be defined at build time to point to the built Metal library. .. list-table:: Build Options :widths: 25 8 :header-rows: 1 * - Option - Default * - MLX_BUILD_TESTS - ON * - MLX_BUILD_EXAMPLES - OFF * - MLX_BUILD_BENCHMARKS - OFF * - MLX_BUILD_METAL - ON * - MLX_BUILD_CPU - ON * - MLX_BUILD_PYTHON_BINDINGS - OFF * - MLX_METAL_DEBUG - OFF * - MLX_BUILD_SAFETENSORS - ON * - MLX_BUILD_GGUF - ON * - MLX_METAL_JIT - OFF .. note:: If you have multiple Xcode installations and wish to use a specific one while building, you can do so by setting the following environment variable before building .. code-block:: shell export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/" Further, you can use the following command to find out which macOS SDK will be used .. code-block:: shell xcrun -sdk macosx --show-sdk-version Binary Size Minimization ~~~~~~~~~~~~~~~~~~~~~~~~ To produce a smaller binary, use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel`` and ``BUILD_SHARED_LIBS=ON``. The MLX CMake build has several additional options to make smaller binaries. For example, if you don't need the CPU backend or support for safetensors and GGUF, you can do: .. code-block:: shell cmake .. \ -DCMAKE_BUILD_TYPE=MinSizeRel \ -DBUILD_SHARED_LIBS=ON \ -DMLX_BUILD_CPU=OFF \ -DMLX_BUILD_SAFETENSORS=OFF \ -DMLX_BUILD_GGUF=OFF \ -DMLX_METAL_JIT=ON The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which contains pre-built GPU kernels.
This substantially reduces the size of the Metal library by compiling kernels at run time the first time they are used in MLX on a given machine. Note that run-time compilation incurs a cold-start cost which can be anywhere from a few hundred milliseconds to a few seconds depending on the application. Once a kernel is compiled, it will be cached by the system. The Metal kernel cache persists across reboots. Linux ^^^^^ To build from source on Linux (CPU only), install the BLAS and LAPACK headers. For example, on Ubuntu, run the following: .. code-block:: shell apt-get update -y apt-get install libblas-dev liblapack-dev liblapacke-dev -y From here, follow the instructions to install either the :ref:`Python ` or :ref:`C++ ` APIs. CUDA ^^^^ To build from source on Linux with CUDA, install the BLAS and LAPACK headers and the CUDA toolkit. For example, on Ubuntu, run the following: .. code-block:: shell wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb dpkg -i cuda-keyring_1.1-1_all.deb apt-get update -y apt-get -y install cuda-toolkit-12-9 apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y When building either the Python or C++ APIs, make sure to pass the cmake flag ``MLX_BUILD_CUDA=ON``. For example, to build the Python API run: .. code-block:: shell CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]" To build the C++ package run: .. code-block:: shell mkdir -p build && cd build cmake .. -DMLX_BUILD_CUDA=ON && make -j Troubleshooting ^^^^^^^^^^^^^^^ Metal not found ~~~~~~~~~~~~~~~ You see the following error when you try to build: .. code-block:: shell error: unable to find utility "metal", not a developer tool or in PATH To fix this, first make sure you have Xcode installed: .. code-block:: shell xcode-select --install Then set the active developer directory: .. code-block:: shell sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer x86 Shell ~~~~~~~~~ ..
_build shell: If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via Rosetta instead of natively. To fix this, find the application in Finder (``/Applications`` for iTerm, ``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”. Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your terminal. Verify that the terminal is now running natively with the following command: .. code-block:: shell $ uname -p arm Also check that cmake is using the correct architecture: .. code-block:: shell $ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR CMAKE_HOST_SYSTEM_PROCESSOR "arm64" If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"`` but the build errors out with "Building for x86_64 on macOS is not supported.", wipe your build cache with ``rm -rf build/`` and try again. ================================================ FILE: docs/src/python/array.rst ================================================ .. _array: Array ===== .. currentmodule:: mlx.core .. autosummary:: :toctree: _autosummary array array.astype array.at array.item array.tolist array.dtype array.itemsize array.nbytes array.ndim array.shape array.size array.real array.imag array.abs array.all array.any array.argmax array.argmin array.conj array.cos array.cummax array.cummin array.cumprod array.cumsum array.diag array.diagonal array.exp array.flatten array.log array.log10 array.log1p array.log2 array.logcumsumexp array.logsumexp array.max array.mean array.min array.moveaxis array.prod array.reciprocal array.reshape array.round array.rsqrt array.sin array.split array.sqrt array.square array.squeeze array.std array.sum array.swapaxes array.transpose array.T array.var array.view ================================================ FILE: docs/src/python/cuda.rst ================================================ CUDA ===== .. currentmodule:: mlx.core.cuda ..
autosummary:: :toctree: _autosummary is_available ================================================ FILE: docs/src/python/data_types.rst ================================================ .. _data_types: Data Types ========== .. currentmodule:: mlx.core The default floating point type is ``float32`` and the default integer type is ``int32``. The table below shows supported values for :obj:`Dtype`. .. list-table:: Supported Data Types :widths: 5 3 20 :header-rows: 1 * - Type - Bytes - Description * - ``bool_`` - 1 - Boolean (``True``, ``False``) data type * - ``uint8`` - 1 - 8-bit unsigned integer * - ``uint16`` - 2 - 16-bit unsigned integer * - ``uint32`` - 4 - 32-bit unsigned integer * - ``uint64`` - 8 - 64-bit unsigned integer * - ``int8`` - 1 - 8-bit signed integer * - ``int16`` - 2 - 16-bit signed integer * - ``int32`` - 4 - 32-bit signed integer * - ``int64`` - 8 - 64-bit signed integer * - ``bfloat16`` - 2 - 16-bit brain float (e8, m7) * - ``float16`` - 2 - 16-bit IEEE float (e5, m10) * - ``float32`` - 4 - 32-bit float * - ``float64`` - 8 - 64-bit double * - ``complex64`` - 8 - 64-bit complex float .. note:: Arrays with type ``float64`` only work with CPU operations. Using ``float64`` arrays on the GPU will result in an exception. Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object documentation for more information. Use :func:`issubdtype` to determine if one ``dtype`` (or category) is a subtype of another category. .. autosummary:: :toctree: _autosummary Dtype DtypeCategory issubdtype finfo ================================================ FILE: docs/src/python/devices_and_streams.rst ================================================ .. _devices_and_streams: Devices and Streams =================== .. currentmodule:: mlx.core ..
autosummary:: :toctree: _autosummary Device Stream default_device set_default_device default_stream new_stream set_default_stream stream synchronize device_count device_info ================================================ FILE: docs/src/python/distributed.rst ================================================ .. _distributed: .. currentmodule:: mlx.core.distributed Distributed Communication ========================== MLX provides a distributed communication package using MPI. The MPI library is loaded at runtime; if MPI is available then distributed communication is also made available. .. autosummary:: :toctree: _autosummary Group is_available init all_sum all_gather send recv recv_like ================================================ FILE: docs/src/python/export.rst ================================================ .. _export: Export Functions ================ .. currentmodule:: mlx.core .. autosummary:: :toctree: _autosummary export_function import_function exporter export_to_dot ================================================ FILE: docs/src/python/fast.rst ================================================ .. _fast: Fast ==== .. currentmodule:: mlx.core.fast .. autosummary:: :toctree: _autosummary rms_norm layer_norm rope scaled_dot_product_attention metal_kernel cuda_kernel ================================================ FILE: docs/src/python/fft.rst ================================================ .. _fft: FFT === .. currentmodule:: mlx.core.fft .. autosummary:: :toctree: _autosummary fft ifft fft2 ifft2 fftn ifftn rfft irfft rfft2 irfft2 rfftn irfftn fftshift ifftshift ================================================ FILE: docs/src/python/linalg.rst ================================================ .. _linalg: Linear Algebra ============== .. currentmodule:: mlx.core.linalg .. 
autosummary:: :toctree: _autosummary inv tri_inv norm cholesky cholesky_inv cross qr svd eigvals eig eigvalsh eigh lu lu_factor pinv solve solve_triangular ================================================ FILE: docs/src/python/memory_management.rst ================================================ Memory Management ================= .. currentmodule:: mlx.core .. autosummary:: :toctree: _autosummary get_active_memory get_peak_memory reset_peak_memory get_cache_memory set_memory_limit set_cache_limit set_wired_limit clear_cache ================================================ FILE: docs/src/python/metal.rst ================================================ Metal ===== .. currentmodule:: mlx.core.metal .. autosummary:: :toctree: _autosummary is_available device_info start_capture stop_capture ================================================ FILE: docs/src/python/nn/distributed.rst ================================================ .. _nn_distributed: Distributed ----------- Helper Routines ^^^^^^^^^^^^^^^ The :code:`mlx.nn.layers.distributed` package contains helpful routines to create sharded layers from existing :class:`Modules `. .. currentmodule:: mlx.nn.layers.distributed .. autosummary:: :toctree: _autosummary shard_linear shard_inplace Layers ^^^^^^ .. currentmodule:: mlx.nn .. autosummary:: :toctree: _autosummary :template: nn-module-template.rst AllToShardedLinear ShardedToAllLinear QuantizedAllToShardedLinear QuantizedShardedToAllLinear ================================================ FILE: docs/src/python/nn/functions.rst ================================================ .. _nn_functions: .. currentmodule:: mlx.nn Functions --------- Layers without parameters (e.g. activation functions) are also provided as simple functions. .. 
autosummary:: :toctree: _autosummary_functions :template: nn-module-template.rst elu celu gelu gelu_approx gelu_fast_approx glu hard_shrink hard_tanh hardswish leaky_relu log_sigmoid log_softmax mish prelu relu relu2 relu6 selu sigmoid silu softmax softmin softplus softshrink step tanh ================================================ FILE: docs/src/python/nn/init.rst ================================================ .. _init: .. currentmodule:: mlx.nn.init Initializers ------------ The ``mlx.nn.init`` package contains commonly used initializers for neural network parameters. Initializers return a function which can be applied to any input :obj:`mlx.core.array` to produce an initialized output. For example: .. code:: python import mlx.core as mx import mlx.nn as nn init_fn = nn.init.uniform() # Produces a [2, 2] uniform matrix param = init_fn(mx.zeros((2, 2))) To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a uniform distribution, you can do: .. code:: python import mlx.nn as nn model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5)) init_fn = nn.init.uniform(low=-0.1, high=0.1) model.apply(init_fn) .. autosummary:: :toctree: _autosummary constant normal uniform identity glorot_normal glorot_uniform he_normal he_uniform ================================================ FILE: docs/src/python/nn/layers.rst ================================================ .. _layers: .. currentmodule:: mlx.nn Layers ------ .. 
autosummary:: :toctree: _autosummary :template: nn-module-template.rst ALiBi AllToShardedLinear AvgPool1d AvgPool2d AvgPool3d BatchNorm CELU Conv1d Conv2d Conv3d ConvTranspose1d ConvTranspose2d ConvTranspose3d Dropout Dropout2d Dropout3d Embedding ELU GELU GLU GroupNorm GRU HardShrink HardTanh Hardswish InstanceNorm LayerNorm LeakyReLU Linear LogSigmoid LogSoftmax LSTM MaxPool1d MaxPool2d MaxPool3d Mish MultiHeadAttention PReLU QuantizedAllToShardedLinear QuantizedEmbedding QuantizedLinear QuantizedShardedToAllLinear RMSNorm ReLU ReLU2 ReLU6 RNN RoPE SELU Sequential ShardedToAllLinear Sigmoid SiLU SinusoidalPositionalEncoding Softmin Softshrink Softsign Softmax Softplus Step Tanh Transformer Upsample ================================================ FILE: docs/src/python/nn/losses.rst ================================================ .. _losses: .. currentmodule:: mlx.nn.losses Loss Functions -------------- .. autosummary:: :toctree: _autosummary_functions :template: nn-module-template.rst binary_cross_entropy cosine_similarity_loss cross_entropy gaussian_nll_loss hinge_loss huber_loss kl_div_loss l1_loss log_cosh_loss margin_ranking_loss mse_loss nll_loss smooth_l1_loss triplet_loss ================================================ FILE: docs/src/python/nn/module.rst ================================================ Module ====== .. currentmodule:: mlx.nn .. autoclass:: Module .. rubric:: Attributes .. autosummary:: :toctree: _autosummary Module.training Module.state .. rubric:: Methods .. 
autosummary:: :toctree: _autosummary Module.apply Module.apply_to_modules Module.children Module.eval Module.filter_and_map Module.freeze Module.leaf_modules Module.load_weights Module.modules Module.named_modules Module.parameters Module.save_weights Module.set_dtype Module.train Module.trainable_parameters Module.unfreeze Module.update Module.update_modules ================================================ FILE: docs/src/python/nn.rst ================================================ .. _nn: .. currentmodule:: mlx.nn Neural Networks =============== Writing arbitrarily complex neural networks in MLX can be done using only :class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the user to write again and again the same simple neural network operations as well as handle all the parameter state and initialization manually and explicitly. The module :mod:`mlx.nn` solves this problem by providing an intuitive way of composing neural network layers, initializing their parameters, freezing them for finetuning and more. Quick Start with Neural Networks --------------------------------- .. code-block:: python import mlx.core as mx import mlx.nn as nn class MLP(nn.Module): def __init__(self, in_dims: int, out_dims: int): super().__init__() self.layers = [ nn.Linear(in_dims, 128), nn.Linear(128, 128), nn.Linear(128, out_dims), ] def __call__(self, x): for i, l in enumerate(self.layers): x = mx.maximum(x, 0) if i > 0 else x x = l(x) return x # The model is created with all its parameters but nothing is initialized # yet because MLX is lazily evaluated mlp = MLP(2, 10) # We can access its parameters by calling mlp.parameters() params = mlp.parameters() print(params["layers"][0]["weight"].shape) # Printing a parameter will cause it to be evaluated and thus initialized print(params["layers"][0]) # We can also force evaluate all parameters to initialize the model mx.eval(mlp.parameters()) # A simple loss function. 
# NOTE: It doesn't matter how it uses the mlp model. It currently captures # it from the local scope. It could be a positional argument or a # keyword argument. def l2_loss(x, y): y_hat = mlp(x) return (y_hat - y).square().mean() # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the # gradient with respect to `mlp.trainable_parameters()` loss_and_grad = nn.value_and_grad(mlp, l2_loss) .. _module_class: The Module Class ---------------- The workhorse of any neural network library is the :class:`Module` class. In MLX the :class:`Module` class is a container of :class:`mlx.core.array` or :class:`Module` instances. Its main function is to provide a way to recursively **access** and **update** its parameters and those of its submodules. Parameters ^^^^^^^^^^ A parameter of a module is any public member of type :class:`mlx.core.array` (its name should not start with ``_``). It can be arbitrarily nested in other :class:`Module` instances or lists and dictionaries. :meth:`Module.parameters` can be used to extract a nested dictionary with all the parameters of a module and its submodules. A :class:`Module` can also keep track of "frozen" parameters. See the :meth:`Module.freeze` method for more details. When using :meth:`mlx.nn.value_and_grad`, the gradients returned are with respect to the trainable (i.e. non-frozen) parameters. Updating the Parameters ^^^^^^^^^^^^^^^^^^^^^^^ MLX modules allow accessing and updating individual parameters. However, most times we need to update large subsets of a module's parameters. This action is performed by :meth:`Module.update`. Inspecting Modules ^^^^^^^^^^^^^^^^^^ The simplest way to see the model architecture is to print it. Following along with the above example, you can print the ``MLP`` with: .. code-block:: python print(mlp) This will display: .. 
code-block:: shell MLP( (layers.0): Linear(input_dims=2, output_dims=128, bias=True) (layers.1): Linear(input_dims=128, output_dims=128, bias=True) (layers.2): Linear(input_dims=128, output_dims=10, bias=True) ) To get more detailed information on the arrays in a :class:`Module` you can use :func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of all the parameters in a :class:`Module` do: .. code-block:: python from mlx.utils import tree_map shapes = tree_map(lambda p: p.shape, mlp.parameters()) As another example, you can count the number of parameters in a :class:`Module` with: .. code-block:: python from mlx.utils import tree_flatten num_params = sum(v.size for _, v in tree_flatten(mlp.parameters())) Value and Grad -------------- Using a :class:`Module` does not preclude using MLX's higher-order function transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However, these function transformations assume pure functions, namely the parameters should be passed as an argument to the function being transformed. There is an easy pattern to achieve that with MLX modules: .. code-block:: python model = ... def f(params, other_inputs): model.update(params) # <---- Necessary to make the model use the passed parameters return model(other_inputs) f(model.trainable_parameters(), mx.zeros((10,))) However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only computes the gradients with respect to the trainable parameters of the model. In detail: - it wraps the passed function with a function that calls :meth:`Module.update` to make sure the model is using the provided parameters. - it calls :meth:`mlx.core.value_and_grad` to transform the function into a function that also computes the gradients with respect to the passed parameters. - it wraps the returned function with a function that passes the trainable parameters as the first argument to the function returned by :meth:`mlx.core.value_and_grad` .. 
autosummary:: :toctree: _autosummary value_and_grad quantize average_gradients fsdp_apply_gradients .. toctree:: nn/module nn/layers nn/functions nn/losses nn/init nn/distributed ================================================ FILE: docs/src/python/ops.rst ================================================ .. _ops: Operations ========== .. currentmodule:: mlx.core .. autosummary:: :toctree: _autosummary abs add addmm all allclose any arange arccos arccosh arcsin arcsinh arctan arctan2 arctanh argmax argmin argpartition argsort array_equal as_strided atleast_1d atleast_2d atleast_3d bitwise_and bitwise_invert bitwise_or bitwise_xor block_masked_mm broadcast_arrays broadcast_to ceil clip concatenate contiguous conj conjugate convolve conv1d conv2d conv3d conv_transpose1d conv_transpose2d conv_transpose3d conv_general cos cosh cummax cummin cumprod cumsum degrees dequantize diag diagonal divide divmod einsum einsum_path equal erf erfinv exp expm1 expand_dims eye flatten floor floor_divide full gather_mm gather_qmm greater greater_equal hadamard_transform identity imag inner isfinite isclose isinf isnan isneginf isposinf issubdtype kron left_shift less less_equal linspace load log log2 log10 log1p logaddexp logcumsumexp logical_not logical_and logical_or logsumexp matmul max maximum mean median meshgrid min minimum moveaxis multiply nan_to_num negative not_equal ones ones_like outer partition pad power prod put_along_axis quantize quantized_matmul radians real reciprocal remainder repeat reshape right_shift roll round rsqrt save savez savez_compressed save_gguf save_safetensors sigmoid sign sin sinh slice slice_update softmax sort split sqrt square squeeze stack std stop_gradient subtract sum swapaxes take take_along_axis tan tanh tensordot tile topk trace transpose tri tril triu unflatten var view where zeros zeros_like ================================================ FILE: docs/src/python/optimizers/common_optimizers.rst 
================================================ .. _common_optimizers: Common Optimizers ================= .. currentmodule:: mlx.optimizers .. autosummary:: :toctree: _autosummary :template: optimizers-template.rst SGD RMSprop Adagrad Adafactor AdaDelta Adam AdamW Adamax Lion MultiOptimizer Muon ================================================ FILE: docs/src/python/optimizers/optimizer.rst ================================================ Optimizer ========= .. currentmodule:: mlx.optimizers .. autoclass:: Optimizer .. rubric:: Attributes .. autosummary:: :toctree: _autosummary Optimizer.state .. rubric:: Methods .. autosummary:: :toctree: _autosummary Optimizer.apply_gradients Optimizer.init Optimizer.update ================================================ FILE: docs/src/python/optimizers/schedulers.rst ================================================ .. _schedulers: Schedulers ========== .. currentmodule:: mlx.optimizers .. autosummary:: :toctree: _autosummary cosine_decay exponential_decay join_schedules linear_schedule step_decay ================================================ FILE: docs/src/python/optimizers.rst ================================================ .. _optimizers: .. currentmodule:: mlx.optimizers Optimizers ========== The optimizers in MLX can be used both with :mod:`mlx.nn` and with pure :mod:`mlx.core` functions. A typical example involves calling :meth:`Optimizer.update` to update a model's parameters based on the loss gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the model's parameters and the **optimizer state**. .. 
code-block:: python # Create a model model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes) mx.eval(model.parameters()) # Create the gradient function and the optimizer loss_and_grad_fn = nn.value_and_grad(model, loss_fn) optimizer = optim.SGD(learning_rate=learning_rate) for e in range(num_epochs): for X, y in batch_iterate(batch_size, train_images, train_labels): loss, grads = loss_and_grad_fn(model, X, y) # Update the model with the gradients. So far no computation has happened. optimizer.update(model, grads) # Compute the new parameters but also the optimizer state. mx.eval(model.parameters(), optimizer.state) Saving and Loading ------------------ To serialize an optimizer, save its state. To load an optimizer, load and set the saved state. Here's a simple example: .. code-block:: python import mlx.core as mx from mlx.utils import tree_flatten, tree_unflatten import mlx.optimizers as optim optimizer = optim.Adam(learning_rate=1e-2) # Perform some updates with the optimizer model = {"w" : mx.zeros((5, 5))} grads = {"w" : mx.ones((5, 5))} optimizer.update(model, grads) # Save the state state = tree_flatten(optimizer.state, destination={}) mx.save_safetensors("optimizer.safetensors", state) # Later on, for example when loading from a checkpoint, # recreate the optimizer and load the state optimizer = optim.Adam(learning_rate=1e-2) state = tree_unflatten(mx.load("optimizer.safetensors")) optimizer.state = state Note that not every optimizer configuration parameter is saved in the state. For example, for Adam the learning rate is saved but the ``betas`` and ``eps`` parameters are not. A good rule of thumb is that if the parameter can be scheduled, then it will be included in the optimizer state. .. toctree:: optimizers/optimizer optimizers/common_optimizers optimizers/schedulers .. 
autosummary:: :toctree: _autosummary clip_grad_norm ================================================ FILE: docs/src/python/random.rst ================================================ .. _random: Random ====== Random sampling functions in MLX use an implicit global PRNG state by default. However, all functions take an optional ``key`` keyword argument for when more fine-grained control or explicit state management is needed. For example, you can generate random numbers with: .. code-block:: python for _ in range(3): print(mx.random.uniform()) which will print a sequence of unique pseudo random numbers. Alternatively you can explicitly set the key: .. code-block:: python key = mx.random.key(0) for _ in range(3): print(mx.random.uniform(key=key)) which will yield the same pseudo random number at each iteration. Following `JAX's PRNG design `_ we use a splittable version of Threefry, which is a counter-based PRNG. .. currentmodule:: mlx.core.random .. autosummary:: :toctree: _autosummary bernoulli categorical gumbel key normal multivariate_normal randint seed split truncated_normal uniform laplace permutation ================================================ FILE: docs/src/python/transforms.rst ================================================ .. _transforms: Transforms ========== .. currentmodule:: mlx.core .. autosummary:: :toctree: _autosummary eval async_eval compile checkpoint custom_function disable_compile enable_compile grad value_and_grad jvp vjp vmap ================================================ FILE: docs/src/python/tree_utils.rst ================================================ .. _utils: Tree Utils ========== In MLX we consider a python tree to be an arbitrarily nested collection of dictionaries, lists and tuples without cycles. Functions in this module that return python trees will be using the default python ``dict``, ``list`` and ``tuple`` but they can usually process objects that inherit from any of these. .. 
note:: Dictionaries should have keys that are valid python identifiers. .. currentmodule:: mlx.utils .. autosummary:: :toctree: _autosummary tree_flatten tree_unflatten tree_map tree_map_with_path tree_reduce ================================================ FILE: docs/src/usage/compile.rst ================================================ .. _compile: Compilation =========== .. currentmodule:: mlx.core MLX has a :func:`compile` function transformation which compiles computation graphs. Function compilation results in smaller graphs by merging common work and fusing certain operations. In many cases this can lead to big improvements in run-time and memory use. Getting started with :func:`compile` is simple, but there are some edge cases that are good to be aware of for more complex graphs and advanced usage. Basics of Compile ----------------- Let's start with a simple example: .. code-block:: python def fun(x, y): return mx.exp(-x) + y x = mx.array(1.0) y = mx.array(2.0) # Regular call, no compilation # Prints: array(2.36788, dtype=float32) print(fun(x, y)) # Compile the function compiled_fun = mx.compile(fun) # Prints: array(2.36788, dtype=float32) print(compiled_fun(x, y)) The output of both the regular function and the compiled function is the same up to numerical precision. The first time you call a compiled function, MLX will build the compute graph, optimize it, and generate and compile code. This can be relatively slow. However, MLX will cache compiled functions, so calling a compiled function multiple times will not initiate a new compilation. This means you should typically compile functions that you plan to use more than once. .. 
code-block:: python def fun(x, y): return mx.exp(-x) + y x = mx.array(1.0) y = mx.array(2.0) compiled_fun = mx.compile(fun) # Compiled here compiled_fun(x, y) # Not compiled again compiled_fun(x, y) # Not compiled again mx.compile(fun)(x, y) There are some important cases to be aware of that can cause a function to be recompiled: * Changing the shape or number of dimensions * Changing the type of any of the inputs * Changing the number of inputs to the function In certain cases only some of the compilation stack will be rerun (for example when changing the shapes) and in other cases the full compilation stack will be rerun (for example when changing the types). In general you should avoid compiling functions too frequently. Another idiom to watch out for is compiling functions which get created and destroyed frequently. This can happen, for example, when compiling an anonymous function in a loop: .. code-block:: python a = mx.array(1.0) # Don't do this, compiles lambda at each iteration for _ in range(5): mx.compile(lambda x: mx.exp(mx.abs(x)))(a) Example Speedup --------------- :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with Transformer-based models. The implementation involves several unary and binary element-wise operations: .. code-block:: python def gelu(x): return x * (1 + mx.erf(x / math.sqrt(2))) / 2 If you use this function with small arrays, it will be overhead bound. If you use it with large arrays it will be memory bandwidth bound. However, all of the operations in the ``gelu`` are fusible into a single kernel with :func:`compile`. This can speed up both cases considerably. Let's compare the runtime of the regular function versus the compiled function. We'll use the following timing helper which does a warm up and handles synchronization: .. 
code-block:: python import time def timeit(fun, x): # warm up for _ in range(10): mx.eval(fun(x)) tic = time.perf_counter() for _ in range(100): mx.eval(fun(x)) toc = time.perf_counter() tpi = 1e3 * (toc - tic) / 100 print(f"Time per iteration {tpi:.3f} (ms)") Now make an array, and benchmark both functions: .. code-block:: python x = mx.random.uniform(shape=(32, 1000, 4096)) timeit(gelu, x) timeit(mx.compile(gelu), x) On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is five times faster. Debugging --------- When a compiled function is first called, it is traced with placeholder inputs. This means you can't evaluate arrays (for example to print their contents) inside compiled functions. .. code-block:: python @mx.compile def fun(x): z = -x print(z) # Crash return mx.exp(z) fun(mx.array(5.0)) For debugging, inspecting arrays can be helpful. One way to do that is to globally disable compilation using the :func:`disable_compile` function or ``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though ``fun`` is compiled: .. code-block:: python @mx.compile def fun(x): z = -x print(z) # Okay return mx.exp(z) mx.disable_compile() fun(mx.array(5.0)) Pure Functions -------------- Compiled functions are intended to be *pure*; that is they should not have side effects. For example: .. code-block:: python state = [] @mx.compile def fun(x, y): z = x + y state.append(z) return mx.exp(z) fun(mx.array(1.0), mx.array(2.0)) # Crash! print(state) After the first call of ``fun``, the ``state`` list will hold a placeholder array. The placeholder does not have any data; it is only used to build the computation graph. Printing such an array results in a crash. You have two options to deal with this. The first option is to simply return ``state`` as an output: .. 
code-block:: python state = [] @mx.compile def fun(x, y): z = x + y state.append(z) return mx.exp(z), state _, state = fun(mx.array(1.0), mx.array(2.0)) # Prints [array(3, dtype=float32)] print(state) In some cases returning updated state can be pretty inconvenient. Hence, :func:`compile` has a parameter to capture implicit outputs: .. code-block:: python from functools import partial state = [] # Tell compile to capture state as an output @partial(mx.compile, outputs=state) def fun(x, y): z = x + y state.append(z) return mx.exp(z) fun(mx.array(1.0), mx.array(2.0)) # Prints [array(3, dtype=float32)] print(state) This is particularly useful for compiling a function which includes an update to a container of arrays, as is commonly done when training the parameters of a :class:`mlx.nn.Module`. Compiled functions will also treat any inputs not in the parameter list as constants. For example: .. code-block:: python state = [mx.array(1.0)] @mx.compile def fun(x): return x + state[0] # Prints array(2, dtype=float32) print(fun(mx.array(1.0))) # Update state state[0] = mx.array(5.0) # Still prints array(2, dtype=float32) print(fun(mx.array(1.0))) In order to have the change of state reflected in the outputs of ``fun`` you again have two options. The first option is to simply pass ``state`` as input to the function. .. code-block:: python state = [mx.array(1.0)] @mx.compile def fun(x, state): return x + state[0] # Prints array(2, dtype=float32) print(fun(mx.array(1.0), state)) # Update state state[0] = mx.array(5.0) # Prints array(6, dtype=float32) print(fun(mx.array(1.0), state)) In some cases this can be pretty inconvenient. Hence, :func:`compile` also has a parameter to capture implicit inputs: .. 
code-block:: python from functools import partial state = [mx.array(1.0)] # Tell compile to capture state as an input @partial(mx.compile, inputs=state) def fun(x): return x + state[0] # Prints array(2, dtype=float32) print(fun(mx.array(1.0))) # Update state state[0] = mx.array(5.0) # Prints array(6, dtype=float32) print(fun(mx.array(1.0))) Compiling Training Graphs ------------------------- This section will step through how to use :func:`compile` with a simple example of a common setup: training a model with :obj:`mlx.nn.Module` using an :obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the full forward, backward, and update with :func:`compile`. To start, here is the simple example without any compilation: .. code-block:: python import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim # 4 examples with 10 features each x = mx.random.uniform(shape=(4, 10)) # 0, 1 targets y = mx.array([0, 1, 0, 1]) # Simple linear model model = nn.Linear(10, 1) # SGD with momentum optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) def loss_fn(model, x, y): logits = model(x).squeeze() return nn.losses.binary_cross_entropy(logits, y) loss_and_grad_fn = nn.value_and_grad(model, loss_fn) # Perform 10 steps of gradient descent for it in range(10): loss, grads = loss_and_grad_fn(model, x, y) optimizer.update(model, grads) mx.eval(model.parameters(), optimizer.state) To compile the update we can put it all in a function and compile it with the appropriate input and output captures. Here's the same example but compiled: .. 
code-block:: python import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim from functools import partial # 4 examples with 10 features each x = mx.random.uniform(shape=(4, 10)) # 0, 1 targets y = mx.array([0, 1, 0, 1]) # Simple linear model model = nn.Linear(10, 1) # SGD with momentum optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) def loss_fn(model, x, y): logits = model(x).squeeze() return nn.losses.binary_cross_entropy(logits, y) # The state that will be captured as input and output state = [model.state, optimizer.state] @partial(mx.compile, inputs=state, outputs=state) def step(x, y): loss_and_grad_fn = nn.value_and_grad(model, loss_fn) loss, grads = loss_and_grad_fn(model, x, y) optimizer.update(model, grads) return loss # Perform 10 steps of gradient descent for it in range(10): loss = step(x, y) # Evaluate the model and optimizer state mx.eval(state) print(loss) .. note:: If you are using a module which performs random sampling such as :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the ``state`` captured by :func:`compile`, i.e. ``state = [model.state, optimizer.state, mx.random.state]``. .. note:: For more examples of compiling full training graphs checkout the `MLX Examples `_ GitHub repo. Transformations with Compile ---------------------------- In MLX function transformations are composable. You can apply any function transformation to the output of any other function transformation. For more on this, see the documentation on :ref:`function transforms `. Compiling transformed functions works just as expected: .. code-block:: python grad_fn = mx.grad(mx.exp) compiled_grad_fn = mx.compile(grad_fn) # Prints: array(2.71828, dtype=float32) print(grad_fn(mx.array(1.0))) # Also prints: array(2.71828, dtype=float32) print(compiled_grad_fn(mx.array(1.0))) .. note:: In order to compile as much as possible, a transformation of a compiled function will not by default be compiled. 
To compile the transformed function simply pass it through :func:`compile`. You can also compile functions which themselves call compiled functions. A good practice is to compile the outermost function to give :func:`compile` the most opportunity to optimize the computation graph: .. code-block:: python @mx.compile def inner(x): return mx.exp(-mx.abs(x)) def outer(x): return inner(inner(x)) # Compiling the outer function is good to do as it will likely # be faster even though the inner functions are compiled fun = mx.compile(outer) .. _shapeless_compile: Shapeless Compilation --------------------- When the shape of an input to a compiled function changes, the function is recompiled. You can compile a function once and run it on inputs with variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this case changes to the shapes of the inputs do not cause the function to be recompiled. .. code-block:: python def fun(x, y): return mx.abs(x + y) compiled_fun = mx.compile(fun, shapeless=True) x = mx.array(1.0) y = mx.array(-2.0) # First call compiles the function print(compiled_fun(x, y)) # Second call with different shapes # does not recompile the function x = mx.array([1.0, -6.0]) y = mx.array([-2.0, 3.0]) print(compiled_fun(x, y)) Use shapeless compilations carefully. Since compilation is not triggered when shapes change, any graphs which are conditional on the input shapes will not work as expected. Shape-dependent computations are common and sometimes subtle to detect. For example: .. code-block:: python def fun(x): return x.reshape(x.shape[0] * x.shape[1], -1) compiled_fun = mx.compile(fun, shapeless=True) x = mx.random.uniform(shape=(2, 3, 4)) out = compiled_fun(x) x = mx.random.uniform(shape=(5, 5, 3)) # Error, can't reshape (5, 5, 3) to (6, -1) out = compiled_fun(x) The second call to the ``compiled_fun`` fails because of the call to :func:`reshape` which uses the static shape of ``x`` in the first call. 
We can fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``: .. code-block:: python def fun(x): return x.flatten(0, 1) compiled_fun = mx.compile(fun, shapeless=True) x = mx.random.uniform(shape=(2, 3, 4)) out = compiled_fun(x) x = mx.random.uniform(shape=(5, 5, 3)) # Ok out = compiled_fun(x) ================================================ FILE: docs/src/usage/distributed.rst ================================================ .. _usage_distributed: Distributed Communication ========================= .. currentmodule:: mlx.core.distributed MLX supports distributed communication operations that allow the computational cost of training or inference to be shared across many physical machines. At the moment we support several different communication backends introduced below. .. list-table:: :widths: 20 80 :header-rows: 1 * - Backend - Description * - :ref:`MPI ` - A full featured and mature distributed communications library. * - :ref:`RING ` - Ring all reduce and all gather over TCP sockets. Always available and usually faster than MPI. * - :ref:`JACCL ` - Low latency communication with RDMA over thunderbolt. Necessary for things like tensor parallelism. * - :ref:`NCCL ` - The backend of choice for CUDA environments. The list of all currently supported operations and their documentation can be seen in the :ref:`API docs`. Getting Started --------------- A distributed program in MLX is as simple as: .. code:: python import mlx.core as mx world = mx.distributed.init() x = mx.distributed.all_sum(mx.ones(10)) print(world.rank(), x) The program above sums the array ``mx.ones(10)`` across all distributed processes. However, when this script is run with ``python`` only one process is launched and no distributed communication takes place. Namely, all operations in ``mx.distributed`` are noops when the distributed group has a size of one. This property allows us to avoid code that checks if we are in a distributed setting similar to the one below: .. 
code:: python import mlx.core as mx x = ... world = mx.distributed.init() # No need for the check we can simply do x = mx.distributed.all_sum(x) if world.size() > 1: x = mx.distributed.all_sum(x) Running Distributed Programs ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ MLX provides ``mlx.launch`` a helper script to launch distributed programs. Continuing with our initial example we can run it on localhost with 4 processes using .. code:: shell $ mlx.launch -n 4 my_script.py 3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) 2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) 1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) 0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) We can also run it on some remote hosts by providing their IPs (provided that the script exists on all hosts and they are reachable by ssh) .. code:: shell $ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py 3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) 2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) 1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) 0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) Consult the dedicated :doc:`usage guide` for more information on using ``mlx.launch``. Selecting Backend ^^^^^^^^^^^^^^^^^ You can select the backend you want to use when calling :func:`init` by passing one of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all available backends. If they all fail then a singleton group is created. .. note:: After a distributed backend is successfully initialized :func:`init` will return **the same backend** if called without arguments or with backend set to ``any``. The following examples aim to clarify the backend initialization logic in MLX: .. code:: python # Case 1: Initialize MPI regardless if it was possible to initialize the ring backend world = mx.distributed.init(backend="mpi") world2 = mx.distributed.init() # subsequent calls return the MPI backend! 
# Case 2: Initialize any backend world = mx.distributed.init(backend="any") # equivalent to no arguments world2 = mx.distributed.init() # same as above # Case 3: Initialize both backends at the same time world_mpi = mx.distributed.init(backend="mpi") world_ring = mx.distributed.init(backend="ring") world_any = mx.distributed.init() # same as MPI because it was initialized first! Distributed Program Examples ---------------------------- - :ref:`Data Parallelism ` - :ref:`Tensor Parallelism ` .. _ring_section: Getting Started with Ring ------------------------- The ring backend does not depend on any third party library so it is always available. It uses TCP sockets so the nodes need to be reachable via a network. As the name suggests the nodes are connected in a ring which means that rank 1 can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3 and so on and so forth. As a result :func:`send` and :func:`recv` with arbitrary sender and receiver are not supported in the ring backend. Defining a Ring ^^^^^^^^^^^^^^^ The easiest way to define and use a ring is via a JSON hostfile and the ``mlx.launch`` :doc:`helper script `. For each node one defines a hostname to ssh into to run commands on this node and one or more IPs that this node will listen to for connections. For example the hostfile below defines a 4 node ring. ``hostname1`` will be rank 0, ``hostname2`` rank 1 etc. .. code:: json [ {"ssh": "hostname1", "ips": ["123.123.123.1"]}, {"ssh": "hostname2", "ips": ["123.123.123.2"]}, {"ssh": "hostname3", "ips": ["123.123.123.3"]}, {"ssh": "hostname4", "ips": ["123.123.123.4"]} ] Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each node, run the script which will listen for connections in each of the provided IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a connection from ``123.123.123.4`` and so on and so forth. 
Thunderbolt Ring ^^^^^^^^^^^^^^^^ Although the ring backend can have benefits over MPI even for Ethernet, its main purpose is to use Thunderbolt rings for higher bandwidth communication. Setting up such thunderbolt rings can be done manually, but is a relatively tedious process. To simplify this, we provide the utility ``mlx.distributed_config``. To use ``mlx.distributed_config`` your computers need to be accessible by ssh via Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the utility as follows: .. code:: shell mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring By default the script will attempt to discover the thunderbolt ring and provide you with the commands to configure each node as well as the ``hostfile.json`` to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes then ``--auto-setup`` can be used to configure them automatically. If you want to go through the process manually, the steps are as follows: * Disable the thunderbolt bridge interface * For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces corresponding to that cable in nodes ``i`` and ``i + 1``. * Set up a unique subnetwork connecting the two nodes for the corresponding interfaces. For instance if the cable corresponds to ``en2`` on node ``i`` and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and ``192.168.0.2`` respectively to the two nodes. For more details you can see the commands prepared by the utility script. .. _jaccl_section: Getting Started with JACCL -------------------------- Starting from macOS 26.2, RDMA over thunderbolt is available and enables low-latency communication between Macs with thunderbolt 5. MLX provides the JACCL backend that uses this functionality to achieve communication latency an order of magnitude lower than the ring backend. .. 
note:: The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective Communication Library* and it is an obvious pun on Nvidia's NCCL but also a tribute to *Jack Beasley* who led the development of RDMA over Thunderbolt at Apple. Enabling RDMA ^^^^^^^^^^^^^ Until the feature matures, enabling RDMA over thunderbolt is slightly more involved and **cannot** be done remotely even with sudo. In fact, it has to be done in macOS recovery: 1. `Start your computer in recovery `_. 2. Open the Terminal by going to Utilities -> Terminal. 3. Run ``rdma_ctl enable``. 4. Reboot. To verify that you have successfully enabled Thunderbolt RDMA, you can run ``ibv_devices`` which should produce something like the following for an M3 Ultra. .. code-block:: bash ~ % ibv_devices device node GUID ------ ---------------- rdma_en2 8096a9d9edbaac05 rdma_en3 8196a9d9edbaac05 rdma_en5 8396a9d9edbaac05 rdma_en4 8296a9d9edbaac05 rdma_en6 8496a9d9edbaac05 rdma_en7 8596a9d9edbaac05 Defining a Mesh ^^^^^^^^^^^^^^^ The JACCL backend supports only fully connected topologies. Namely, there needs to be a thunderbolt cable connecting all pairs of Macs directly. For example, in the following topology visualizations, the left one is valid because there is a connection from any node to any other node, while in the one on the right, M3 Ultra 1 is not connected to M3 Ultra 2. .. raw:: html
M3 Ultra thunderbolt mesh

Fully connected mesh of four M3 Ultra.

M3 Ultra broken thunderbolt mesh

Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).

Similar to the ring backend, the easiest way to use JACCL with MLX is to write a JSON hostfile that will be used by ``mlx.launch``. The hostfile needs to contain - Hostnames to use for launching scripts via ssh - An IP for rank 0 that is reachable by all nodes - A list of rdma devices that connect each node to each other node The following JSON defines the valid 4-node mesh from the image above. .. code-block:: json [ { "ssh": "m3-ultra-1", "ips": ["123.123.123.1"], "rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"] }, { "ssh": "m3-ultra-2", "ips": [], "rdma": ["rdma_en5", null, "rdma_en3", "rdma_en4"] }, { "ssh": "m3-ultra-3", "ips": [], "rdma": ["rdma_en4", "rdma_en3", null, "rdma_en5"] }, { "ssh": "m3-ultra-4", "ips": [], "rdma": ["rdma_en3", "rdma_en4", "rdma_en5", null] } ] Even though TCP/IP is not used when communicating with Thunderbolt RDMA, disabling the thunderbolt bridge is still required as well as setting up isolated local networks for each thunderbolt connection. All of the above can be done instead via ``mlx.distributed_config``. This helper script will - ssh into each node - extract the thunderbolt connectivity - check for a valid mesh - provide the commands to configure each node (or run them if sudo is available) - generate the hostfile to be used with ``mlx.launch`` Putting It All Together ^^^^^^^^^^^^^^^^^^^^^^^^ For example launching a distributed MLX script that uses JACCL is fairly simple if the nodes are reachable via ssh and have password-less sudo. First, connect all the thunderbolt cables. Then we can verify the connections by using the ``mlx.distributed_config`` script to visualize them. .. code-block:: mlx.distributed_config --verbose \ --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \ --over thunderbolt --dot | dot -Tpng | open -f -a Preview After making sure that everything looks right we can auto-configure the nodes and save the hostfile to ``m3-ultra-jaccl.json`` by running: .. 
code-block:: mlx.distributed_config --verbose \ --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \ --over thunderbolt --backend jaccl \ --auto-setup --output m3-ultra-jaccl.json And now we are ready to run a distributed MLX script such as distributed inference of a gigantic model using MLX LM. .. code-block:: mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \ --env MLX_METAL_FAST_SYNCH=1 -- \ # <--- important /path/to/remote/python -m mlx_lm chat --model mlx-community/DeepSeek-R1-0528-4bit .. note:: Defining the environment variable ``MLX_METAL_FAST_SYNCH=1`` enables a different, faster way of synchronizing between the GPU and the CPU. It is not specific to the JACCL backend and can be used in all cases where the CPU and GPU need to collaborate for some computation and is pretty critical for low-latency communication since the communication is done by the CPU. .. _nccl_section: Getting Started with NCCL ------------------------- MLX on CUDA environments ships with the ability to talk to `NCCL `_ which is a high-performance collective communication library that supports both multi-gpu and multi-node setups. For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all it takes to run a distributed job is .. code-block:: mlx.launch -n 8 test.py # perfect for interactive scripts mlx.launch -n 8 python -m mlx_lm chat --model my-model You can also use ``mlx.launch`` to ssh to a remote node and launch a script with the same ease .. code-block:: mlx.launch --hosts my-cuda-node -n 8 test.py In many cases you may not want to use ``mlx.launch`` with the NCCL backend because the cluster scheduler will be the one launching the processes. You can :ref:`see which environment variables need to be defined ` in order for the MLX NCCL backend to be initialized correctly. .. _mpi_section: Getting Started with MPI ------------------------ MLX already comes with the ability to "talk" to `MPI `_ if it is installed on the machine. 
Launching distributed MLX programs that use MPI can be done with ``mpirun`` as expected. However, in the following examples we will be using ``mlx.launch --backend mpi`` which takes care of some nuisances such as setting absolute paths for the ``mpirun`` executable and the ``libmpi.dylib`` shared library. The simplest possible usage is the following which, assuming the minimal example at the beginning of this page, should result in: .. code:: shell $ mlx.launch --backend mpi -n 2 test.py 1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32) 0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32) The above launches two processes on the same (local) machine and we can see both standard output streams. The processes send the array of 1s to each other and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would print 4 etc. Installing MPI ^^^^^^^^^^^^^^ MPI can be installed with Homebrew, pip, using the Anaconda package manager, or compiled from source. Most of our testing is done using ``openmpi`` installed with the Anaconda package manager as follows: .. code:: shell $ conda install conda-forge::openmpi Installing with Homebrew or pip requires specifying the location of ``libmpi.dylib`` so that MLX can find it and load it at runtime. This can simply be achieved by passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is done automatically by ``mlx.launch``. Some environments use a non-standard library filename that can be specified using the ``MPI_LIBNAME`` environment variable. This is automatically taken care of by ``mlx.launch`` as well. .. code:: shell $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py $ # or simply $ mlx.launch --backend mpi -n 2 test.py Setting up Remote Hosts ^^^^^^^^^^^^^^^^^^^^^^^ MPI can automatically connect to remote hosts and set up the communication over the network if the remote hosts can be accessed via ssh.
A good checklist to debug connectivity issues is the following: * ``ssh hostname`` works from all machines to all machines without asking for password or host confirmation * ``mpirun`` is accessible on all machines. * Ensure that the ``hostname`` used by MPI is the one that you have configured in the ``.ssh/config`` files on all machines. Tuning MPI All Reduce ^^^^^^^^^^^^^^^^^^^^^ .. note:: For faster all reduce, consider using the ring backend either with Thunderbolt connections or over Ethernet. Configure MPI to use N tcp connections between each host to improve bandwidth by passing ``--mca btl_tcp_links N``. Force MPI to use the most performant network interface by setting ``--mca btl_tcp_if_include `` where ```` should be the interface you want to use. .. _no_mlx_launch: Distributed Without ``mlx.launch`` ---------------------------------- None of the implementations of the distributed backends require launching with ``mlx.launch``. The script simply connects to each host, starts a process per rank, and sets up the necessary environment variables before delegating to your MLX script. See the :doc:`dedicated documentation page ` for more details. For many use-cases this will be the easiest way to perform distributed computations in MLX. However, there may be reasons that you cannot or should not use ``mlx.launch``. A common such case is the use of a scheduler that starts all the processes for you on machines undetermined at the time of scheduling the job. Below we list the environment variables required to use each backend. Ring ^^^^^^ **MLX_RANK** should contain a single 0-based integer that defines the rank of the process. **MLX_HOSTFILE** should contain the path to a json file that contains IPs and ports for each rank to listen to, something like the following: ..
code-block:: json [ ["123.123.1.1:5000", "123.123.1.2:5000"], ["123.123.2.1:5000", "123.123.2.2:5000"], ["123.123.3.1:5000", "123.123.3.2:5000"], ["123.123.4.1:5000", "123.123.4.2:5000"] ] **MLX_RING_VERBOSE** is optional and, if set to 1, it enables more logging from the distributed backend. JACCL ^^^^^ **MLX_RANK** should contain a single 0-based integer that defines the rank of the process. **MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 listens on and that all the other ranks connect to in order to establish the RDMA connections. **MLX_IBV_DEVICES** should contain the path to a json file that contains the ibverbs device names that connect each node to each other node, something like the following: .. code-block:: json [ [null, "rdma_en5", "rdma_en4", "rdma_en3"], ["rdma_en5", null, "rdma_en3", "rdma_en4"], ["rdma_en4", "rdma_en3", null, "rdma_en5"], ["rdma_en3", "rdma_en4", "rdma_en5", null] ] NCCL ^^^^^ **MLX_RANK** should contain a single 0-based integer that defines the rank of the process. **MLX_WORLD_SIZE** should contain the total number of processes that will be launched. **NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all hosts can connect to in order to establish the NCCL communication. **CUDA_VISIBLE_DEVICES** should contain the local index of the gpu that corresponds to this process. Of course any `other environment variable `_ that is used by NCCL can be set. .. _tips_and_tricks: Tips and Tricks ---------------- This is a small collection of tips to help you better utilize the distributed communication capabilities of MLX. - *Test locally first.* You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a small-scale test on a single node first. - *Batch your communication.* As described in the :ref:`training example `, performing a lot of small communications can hurt performance. Copy the approach of :func:`mlx.nn.average_gradients` to gather many small communications in a single large one.
- *Visualize the connectivity.* Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to visualize the connections and make sure that the cables are connected correctly. See the :ref:`JACCL section ` for examples. - *Use the debugger.* ``mlx.launch`` is meant for interactive use. It broadcasts stdin to all processes and gathers stdout from all processes. This makes using ``pdb`` a breeze. ================================================ FILE: docs/src/usage/export.rst ================================================ .. _export_usage: Exporting Functions =================== .. currentmodule:: mlx.core MLX has an API to export and import functions to and from a file. This lets you run computations written in one MLX front-end (e.g. Python) in another MLX front-end (e.g. C++). This guide walks through the basics of the MLX export API with some examples. To see the full list of functions, check out the :ref:`API documentation `. Basics of Exporting ------------------- Let's start with a simple example: .. code-block:: python def fun(x, y): return x + y x = mx.array(1.0) y = mx.array(1.0) mx.export_function("add.mlxfn", fun, x, y) To export a function, provide sample input arrays that the function can be called with. The data doesn't matter, but the shapes and types of the arrays do. In the above example we exported ``fun`` with two ``float32`` scalar arrays. We can then import the function and run it: ..
code-block:: python add_fun = mx.import_function("add.mlxfn") out, = add_fun(mx.array(1.0), mx.array(2.0)) # Prints: array(3, dtype=float32) print(out) out, = add_fun(mx.array(1.0), mx.array(3.0)) # Prints: array(4, dtype=float32) print(out) # Raises an exception add_fun(mx.array(1), mx.array(3.0)) # Raises an exception add_fun(mx.array([1.0, 2.0]), mx.array(3.0)) Notice the third and fourth calls to ``add_fun`` raise exceptions because the shapes and types of the inputs are different than the shapes and types of the example inputs we exported the function with. Also notice that even though the original ``fun`` returns a single output array, the imported function always returns a tuple of one or more arrays. The inputs to :func:`export_function` and to an imported function can be specified as variable positional arguments or as a tuple of arrays: .. code-block:: python def fun(x, y): return x + y x = mx.array(1.0) y = mx.array(1.0) # Both arguments to fun are positional mx.export_function("add.mlxfn", fun, x, y) # Same as above mx.export_function("add.mlxfn", fun, (x, y)) imported_fun = mx.import_function("add.mlxfn") # Ok out, = imported_fun(x, y) # Also ok out, = imported_fun((x, y)) You can pass example inputs to functions as positional or keyword arguments. If you use keyword arguments to export the function, then you have to use the same keyword arguments when calling the imported function. .. code-block:: python def fun(x, y): return x + y # One argument to fun is positional, the other is a kwarg mx.export_function("add.mlxfn", fun, x, y=y) imported_fun = mx.import_function("add.mlxfn") # Ok out, = imported_fun(x, y=y) # Also ok out, = imported_fun((x,), {"y": y}) # Raises since the keyword argument is missing out, = imported_fun(x, y) # Raises since the keyword argument has the wrong key out, = imported_fun(x, z=y) Exporting Modules ----------------- An :obj:`mlx.nn.Module` can be exported with or without the parameters included in the exported function. 
Here's an example: .. code-block:: python model = nn.Linear(4, 4) mx.eval(model.parameters()) def call(x): return model(x) mx.export_function("model.mlxfn", call, mx.zeros(4)) In the above example, the :obj:`mlx.nn.Linear` module is exported. Its parameters are also saved to the ``model.mlxfn`` file. .. note:: For enclosed arrays inside an exported function, be extra careful to ensure they are evaluated. The computation graph that gets exported will include the computation that produces enclosed inputs. If the above example were missing ``mx.eval(model.parameters())``, the exported function would include the random initialization of the :obj:`mlx.nn.Module` parameters. If you only want to export the ``Module.__call__`` function without the parameters, pass them as inputs to the ``call`` wrapper: .. code-block:: python model = nn.Linear(4, 4) mx.eval(model.parameters()) def call(x, **params): # Set the model's parameters to the input parameters model.update(tree_unflatten(list(params.items()))) return model(x) params = tree_flatten(model.parameters(), destination={}) mx.export_function("model.mlxfn", call, (mx.zeros(4),), params) Exporting with a Callback ------------------------- To inspect the exported graph, you can pass a callback instead of a file path to :func:`export_function`. .. code-block:: python def fun(x): return x.astype(mx.int32) def callback(args): print(args) mx.export_function(callback, fun, mx.array([1.0, 2.0])) The argument to the callback (``args``) is a dictionary which includes a ``type`` field. The possible types are: * ``"inputs"``: The ordered positional inputs to the exported function * ``"keyword_inputs"``: The keyword specified inputs to the exported function * ``"outputs"``: The ordered outputs of the exported function * ``"constants"``: Any graph constants * ``"primitives"``: Inner graph nodes representing the operations Each type has additional fields in the ``args`` dictionary.
Shapeless Exports ----------------- Just like :func:`compile`, functions can also be exported for dynamically shaped inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter` to export a function which can be used for inputs with variable shapes: .. code-block:: python mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True) imported_abs = mx.import_function("fun.mlxfn") # Ok out, = imported_abs(mx.array([-1.0])) # Also ok out, = imported_abs(mx.array([-1.0, -2.0])) With ``shapeless=False`` (which is the default), the second call to ``imported_abs`` would raise an exception with a shape mismatch. Shapeless exporting works the same as shapeless compilation and should be used carefully. See the :ref:`documentation on shapeless compilation ` for more information. Exporting Multiple Traces ------------------------- In some cases, functions build different computation graphs for different input arguments. A simple way to manage this is to export to a new file for each set of inputs. This is a fine option in many cases. But it can be suboptimal if the exported functions have a large amount of duplicate constant data (for example, the parameters of a :obj:`mlx.nn.Module`). The export API in MLX lets you export multiple traces of the same function to a single file by creating an exporting context manager with :func:`exporter`: .. code-block:: python def fun(x, y=None): constant = mx.array(3.0) if y is not None: x += y return x + constant with mx.exporter("fun.mlxfn", fun) as exporter: exporter(mx.array(1.0)) exporter(mx.array(1.0), y=mx.array(0.0)) imported_function = mx.import_function("fun.mlxfn") # Call the function with y=None out, = imported_function(mx.array(1.0)) print(out) # Call the function with y specified out, = imported_function(mx.array(1.0), y=mx.array(1.0)) print(out) In the above example, the function's constant data (i.e. ``constant``) is only saved once.
Transformations with Imported Functions --------------------------------------- Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work on imported functions just like regular Python functions: .. code-block:: python def fun(x): return mx.sin(x) x = mx.array(0.0) mx.export_function("sine.mlxfn", fun, x) imported_fun = mx.import_function("sine.mlxfn") # Take the derivative of the imported function dfdx = mx.grad(lambda x: imported_fun(x)[0]) # Prints: array(1, dtype=float32) print(dfdx(x)) # Compile the imported function compiled_fun = mx.compile(imported_fun) # Prints: array(0, dtype=float32) print(compiled_fun(x)[0]) Importing Functions in C++ -------------------------- Importing and running functions in C++ is basically the same as importing and running them in Python. First, follow the :ref:`instructions ` to set up a simple C++ project that uses MLX as a library. Next, export a simple function from Python: .. code-block:: python def fun(x, y): return mx.exp(x + y) x = mx.array(1.0) y = mx.array(1.0) mx.export_function("fun.mlxfn", fun, x, y) Import and run the function in C++ with only a few lines of code: .. code-block:: c++ auto fun = mx::import_function("fun.mlxfn"); auto inputs = {mx::array(1.0), mx::array(1.0)}; auto outputs = fun(inputs); // Prints: array(7.38906, dtype=float32) std::cout << outputs[0] << std::endl; Imported functions can be transformed in C++ just like in Python. Use ``std::vector`` for positional arguments and ``std::map`` for keyword arguments when calling imported functions in C++. More Examples ------------- Here are a few more complete examples exporting more complex functions from Python and importing and running them in C++: * `Inference and training a multi-layer perceptron `_ ================================================ FILE: docs/src/usage/function_transforms.rst ================================================ .. _function_transforms: Function Transforms =================== ..
currentmodule:: mlx.core MLX uses composable function transformations for automatic differentiation, vectorization, and compute graph optimizations. To see the complete list of function transformations, check out the :ref:`API documentation `. The key idea behind composable function transformations is that every transformation returns a function which can be further transformed. Here is a simple example: .. code-block:: shell >>> dfdx = mx.grad(mx.sin) >>> dfdx(mx.array(mx.pi)) array(-1, dtype=float32) >>> mx.cos(mx.array(mx.pi)) array(-1, dtype=float32) The output of :func:`grad` on :func:`sin` is simply another function. In this case it is the gradient of the sine function which is exactly the cosine function. To get the second derivative you can do: .. code-block:: shell >>> d2fdx2 = mx.grad(mx.grad(mx.sin)) >>> d2fdx2(mx.array(mx.pi / 2)) array(-1, dtype=float32) >>> mx.sin(mx.array(mx.pi / 2)) array(1, dtype=float32) Using :func:`grad` on the output of :func:`grad` is always ok. You keep getting higher-order derivatives. Any of the MLX function transformations can be composed in any order to any depth. See the following sections for more information on :ref:`automatic differentiation ` and :ref:`automatic vectorization `. For more information on :func:`compile` see the :ref:`compile documentation `. Automatic Differentiation ------------------------- .. _auto diff: Automatic differentiation in MLX works on functions rather than on implicit graphs. .. note:: If you are coming to MLX from PyTorch, you no longer need functions like ``backward``, ``zero_grad``, and ``detach``, or properties like ``requires_grad``. The most basic example is taking the gradient of a scalar-valued function as we saw above. You can use the :func:`grad` and :func:`value_and_grad` functions to compute gradients of more complex functions. By default these functions compute the gradient with respect to the first argument: ..
code-block:: python def loss_fn(w, x, y): return mx.mean(mx.square(w * x - y)) w = mx.array(1.0) x = mx.array([0.5, -0.5]) y = mx.array([1.5, -1.5]) # Computes the gradient of loss_fn with respect to w: grad_fn = mx.grad(loss_fn) dloss_dw = grad_fn(w, x, y) # Prints array(-1, dtype=float32) print(dloss_dw) # To get the gradient with respect to x we can do: grad_fn = mx.grad(loss_fn, argnums=1) dloss_dx = grad_fn(w, x, y) # Prints array([-1, 1], dtype=float32) print(dloss_dx) One way to get the loss and gradient is to call ``loss_fn`` followed by ``grad_fn``, but this can result in a lot of redundant work. Instead, you should use :func:`value_and_grad`. Continuing the above example: .. code-block:: python # Computes the gradient of loss_fn with respect to w: loss_and_grad_fn = mx.value_and_grad(loss_fn) loss, dloss_dw = loss_and_grad_fn(w, x, y) # Prints array(1, dtype=float32) print(loss) # Prints array(-1, dtype=float32) print(dloss_dw) You can also take the gradient with respect to arbitrarily nested Python containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or :obj:`dict`). Suppose we wanted a weight and a bias parameter in the above example. A nice way to do that is the following: .. code-block:: python def loss_fn(params, x, y): w, b = params["weight"], params["bias"] h = w * x + b return mx.mean(mx.square(h - y)) params = {"weight": mx.array(1.0), "bias": mx.array(0.0)} x = mx.array([0.5, -0.5]) y = mx.array([1.5, -1.5]) # Computes the gradient of loss_fn with respect to both the # weight and bias: grad_fn = mx.grad(loss_fn) grads = grad_fn(params, x, y) # Prints # {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)} print(grads) Notice the tree structure of the parameters is preserved in the gradients. In some cases you may want to stop gradients from propagating through a part of the function. You can use the :func:`stop_gradient` for that. Automatic Vectorization ----------------------- .. 
_vmap: Use :func:`vmap` to automate vectorizing complex functions. Here we'll go through a basic and contrived example for the sake of clarity, but :func:`vmap` can be quite powerful for more complex functions which are difficult to optimize by hand. .. warning:: Some operations are not yet supported with :func:`vmap`. If you encounter an error like: ``ValueError: Primitive's vmap not implemented.`` file an `issue <https://github.com/ml-explore/mlx/issues>`_ and include your function. We will prioritize including it. A naive way to add the elements from two sets of vectors is with a loop: .. code-block:: python xs = mx.random.uniform(shape=(4096, 100)) ys = mx.random.uniform(shape=(100, 4096)) def naive_add(xs, ys): return [xs[i] + ys[:, i] for i in range(xs.shape[0])] Instead you can use :func:`vmap` to automatically vectorize the addition: .. code-block:: python # Vectorize over the first dimension of x and the # second dimension of y vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1)) The ``in_axes`` parameter can be used to specify which dimensions of the corresponding input to vectorize over. Similarly, use ``out_axes`` to specify where the vectorized axes should be in the outputs. Let's time these two different versions: .. code-block:: python import timeit print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100)) print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100)) On an M1 Max the naive version takes in total ``5.639`` seconds whereas the vectorized version takes only ``0.024`` seconds, more than 200 times faster. Of course, this operation is quite contrived. A better approach is to simply do ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy. ================================================ FILE: docs/src/usage/indexing.rst ================================================ .. _indexing: Indexing Arrays =============== ..
currentmodule:: mlx.core For the most part, indexing an MLX :obj:`array` works the same as indexing a NumPy :obj:`numpy.ndarray`. See the `NumPy documentation <https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on how that works. For example, you can use regular integers and slices (:obj:`slice`) to index arrays: .. code-block:: shell >>> arr = mx.arange(10) >>> arr[3] array(3, dtype=int32) >>> arr[-2] # negative indexing works array(8, dtype=int32) >>> arr[2:8:2] # start, stop, stride array([2, 4, 6], dtype=int32) For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy: .. code-block:: shell >>> arr = mx.arange(8).reshape(2, 2, 2) >>> arr[:, :, 0] array([[0, 2], [4, 6]], dtype=int32) >>> arr[..., 0] array([[0, 2], [4, 6]], dtype=int32) You can index with ``None`` to create a new axis: .. code-block:: shell >>> arr = mx.arange(8) >>> arr.shape [8] >>> arr[None].shape [1, 8] You can also use an :obj:`array` to index another :obj:`array`: .. code-block:: shell >>> arr = mx.arange(10) >>> idx = mx.array([5, 7]) >>> arr[idx] array([5, 7], dtype=int32) Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices works just as in NumPy. Other functions which may be useful for indexing arrays are :func:`take` and :func:`take_along_axis`. Differences from NumPy ---------------------- .. Note:: MLX indexing is different from NumPy indexing in two important ways: * Indexing does not perform bounds checking. Indexing out of bounds is undefined behavior. * Boolean mask based indexing is supported for assignment only (see :ref:`boolean-mask-assignment`). The reason for the lack of bounds checking is that exceptions cannot propagate from the GPU. Performing bounds checking for array indices before launching the kernel would be extremely inefficient. Reading (gathering) values with boolean masks is something that MLX may support in the future. In general, MLX has limited support for operations for which output *shapes* are dependent on input *data*.
Other examples of these types of operations which MLX does not yet support include :func:`numpy.nonzero` and the single input version of :func:`numpy.where`. In Place Updates ---------------- In place updates to indexed arrays are possible in MLX. For example: .. code-block:: shell >>> a = mx.array([1, 2, 3]) >>> a[2] = 0 >>> a array([1, 2, 0], dtype=int32) Just as in NumPy, in place updates will be reflected in all references to the same array: .. code-block:: shell >>> a = mx.array([1, 2, 3]) >>> b = a >>> b[2] = 0 >>> b array([1, 2, 0], dtype=int32) >>> a array([1, 2, 0], dtype=int32) Note that unlike NumPy, slicing an array creates a copy, not a view. So mutating it does not mutate the original array: .. code-block:: shell >>> a = mx.array([1, 2, 3]) >>> b = a[:] >>> b[2] = 0 >>> b array([1, 2, 0], dtype=int32) >>> a array([1, 2, 3], dtype=int32) Also unlike NumPy, updates to the same location are nondeterministic: .. code-block:: shell >>> a = mx.array([1, 2, 3]) >>> a[[0, 0]] = mx.array([4, 5]) The first element of ``a`` could be ``4`` or ``5``. Transformations of functions which use in-place updates are allowed and work as expected. For example: .. code-block:: python def fun(x, idx): x[idx] = 2.0 return x.sum() dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1])) print(dfdx) # Prints: array([1, 0, 1], dtype=float32) In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx`` and ones elsewhere. .. _boolean-mask-assignment: Boolean Mask Assignment ----------------------- MLX supports boolean indices using NumPy syntax. A mask must already be a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``. Other index types are routed through the standard scatter code. .. 
code-block:: shell >>> a = mx.array([1.0, 2.0, 3.0]) >>> mask = mx.array([True, False, True]) >>> updates = mx.array([5.0, 6.0]) >>> a[mask] = updates >>> a array([5.0, 2.0, 6.0], dtype=float32) Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar assignments, ``updates`` must provide at least as many elements as there are ``True`` entries in ``mask``. .. code-block:: shell >>> a = mx.zeros((2, 3)) >>> mask = mx.array([[True, False, True], [False, False, True]]) >>> a[mask] = 1.0 >>> a array([[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]], dtype=float32) Boolean masks follow NumPy semantics: - The mask shape must match the shape of the axes it indexes exactly. The only exception is a scalar boolean mask, which broadcasts to the full array. - Any axes not covered by the mask are taken in full. .. code-block:: shell >>> a = mx.arange(1000).reshape(10, 10, 10) >>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1 The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]`` selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``. Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed axes and therefore raise errors. ================================================ FILE: docs/src/usage/launching_distributed.rst ================================================ :orphan: .. _usage_launch_distributed: Launching Distributed Programs ============================== .. currentmodule:: mlx.core.distributed The MLX python package provides two utilities to help you configure your Macs for distributed computation and also launch distributed programs on multiple nodes or with many processes in a single node. These utilities are aptly named - ``mlx.launch`` - ``mlx.distributed_config`` See the :doc:`distributed docs ` for an introduction and getting-started guides to the various backends. 
``mlx.distributed_config`` --------------------------- Unless you are launching distributed jobs locally for development or in multi-GPU CUDA environments, you have several Macs that you need to configure for distributed communication with MLX. ``mlx.distributed_config`` aims to automate the process of configuring the network interfaces (especially for communication over thunderbolt) and also creating the hostfile to be used with ``mlx.launch``. We will analyse 3 cases of using ``mlx.distributed_config``: 1. RDMA over thunderbolt using JACCL 2. TCP/IP over thunderbolt using the ring backend 3. TCP/IP over ethernet using the ring backend JACCL ^^^^^^^ After following :ref:`the steps to enable RDMA ` you can run the following command to configure the nodes and create the hostfile. .. code-block:: mlx.distributed_config --verbose --backend jaccl \ --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 --over thunderbolt \ --auto-setup --output m3-ultra-jaccl.json Let's walk through the steps that the script takes to configure the nodes. 1. ssh to all nodes to verify that they are reachable 2. Extract the thunderbolt connectivity. Namely run commands on each node to calculate which node is connected to which other node. 3. Verify that we have a valid fully connected mesh 4. Check that RDMA is enabled 5. Extract the ethernet IP from interface en0 6. Disable the thunderbolt bridge and set up peer to peer networks for each thunderbolt cable 7. Write the hostfile Knowing the above steps allows you to manually configure the nodes but also debug any configuration issue. For instance changing the Ethernet IP to a different interface directly in the config is possible (as long as it is reachable from all nodes). The ``--auto-setup`` argument requires password-less sudo on each node. If it isn't available then the configuration script will print commands to be run on each node.
Ring over thunderbolt ^^^^^^^^^^^^^^^^^^^^^ Setting up a ring backend over thunderbolt only requires changing the ``--backend`` from ``jaccl`` to ``ring``. The steps are very similar with the main difference being that instead of verifying that the nodes are fully connected, the script attempts to identify a ring topology (or multiple rings). Ring over Ethernet ^^^^^^^^^^^^^^^^^^ Configuring the ring backend over ethernet doesn't require setting up network interfaces and as such it simply extracts the ``en0`` IP from each node and writes the hostfile. Debugging cable connections ^^^^^^^^^^^^^^^^^^^^^^^^^^^ ``mlx.distributed_config`` can help you debug the connectivity of your nodes over thunderbolt by exporting a graph of the connections. Running .. code-block:: mlx.distributed_config --verbose \ --hosts host1,host2,host3,host4 \ --over thunderbolt --dot will export a `GraphViz <https://graphviz.org/>`_ representation of the connections between the nodes which makes it very easy to figure out which cable is not connected correctly. See :ref:`the JACCL section ` for an example. ``mlx.launch`` -------------- The minimal usage example of ``mlx.launch`` is simply .. code:: shell mlx.launch --hosts ip1,ip2 my_script.py or for testing on localhost .. code:: shell mlx.launch -n 2 my_script.py The ``mlx.launch`` command connects to the provided hosts and launches the input script on each host. It monitors each of the launched processes and terminates the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated. It also takes care of forwarding the output of each remote process to stdout and stderr respectively. Importantly, it also broadcasts stdin to each process which enables interactive programs to work in distributed mode as well as debugging using the interactive debugger. Providing Hosts ^^^^^^^^^^^^^^^^ Hosts can be provided as command line arguments, like above, but the only way to fully define a list of hosts is via a JSON hostfile.
The hostfile has a very simple schema. It is simply a list of objects that define each host via a hostname to ssh to and a list of IPs to utilize for the communication. .. code:: json [ {"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]}, {"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]} ] You can use ``mlx.distributed_config --over ethernet`` to create a hostfile with IPs corresponding to the ``en0`` interface. Setting up Remote Hosts ^^^^^^^^^^^^^^^^^^^^^^^^ In order to be able to launch the script on each host we need to be able to connect via ssh. Moreover the input script and python binary need to be on each host and on the same path. A good checklist to debug errors is the following: * ``ssh hostname`` works without asking for password or host confirmation * the python binary is available on all hosts at the same path. You can use ``mlx.launch --print-python`` to see what that path is. * the script you want to run is available on all hosts at the same path If you are launching from a node with a completely different setup than the nodes that the program will run on, you can specify ``--no-verify-script`` so that ``mlx.launch`` does not attempt to verify that the executable and script exist locally before launching the distributed job. .. _ring_specifics: Ring Specifics ^^^^^^^^^^^^^^ The :ref:`ring ` backend, which is also the default backend, can be explicitly selected with the argument ``--backend ring``. The ring backend has some specific requirements and arguments that are different to other backends: * The argument ``--hosts`` only accepts IPs and not hostnames. If we need to ssh to a hostname that does not correspond to the IP we want to bind to we have to provide a hostfile. * ``--starting-port`` defines the port to bind to on the remote hosts. Specifically rank 0 for the first IP will use this port and each subsequent IP or rank will add 1 to this port. 
* ``--connections-per-ip`` allows us to increase the number of connections between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for ``mpirun``. .. _jaccl_specifics: JACCL Specifics ^^^^^^^^^^^^^^^^ The :ref:`JACCL ` backend can be selected with the argument ``--backend jaccl``. A hostfile is necessary to launch with this backend because it needs to contain the RDMA devices connecting each node to each other node. NCCL Specifics ^^^^^^^^^^^^^^ The :ref:`NCCL ` backend is the default backend for CUDA environments. When launching from a Mac to a Linux machine with CUDA then the backend should be selected using ``--backend nccl``. The ``--repeat-hosts, -n`` argument should be used to launch multi-node and multi-gpu jobs. For instance .. code-block:: mlx.launch --backend nccl --hosts linux-1,linux-2 -n 8 --no-verify-script -- ./my-job.sh will attempt to launch 16 processes, 8 on each node that will all run ``my-job.sh``. .. _mpi_specifics: MPI Specifics ^^^^^^^^^^^^^ One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case, ``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover, * The IPs in the hostfile are ignored * The ssh connectivity requirement is stronger as every node needs to be able to connect to every other node * ``mpirun`` needs to be available on every node at the same path Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance to choose a specific interface for the byte-transfer-layer of MPI we can call ``mlx.launch`` as follows: .. code:: shell mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py ================================================ FILE: docs/src/usage/lazy_evaluation.rst ================================================ .. _lazy eval: Lazy Evaluation =============== .. currentmodule:: mlx.core Why Lazy Evaluation ------------------- When you perform operations in MLX, no computation actually happens. 
Instead a compute graph is recorded. The actual computation only happens if an :func:`eval` is performed. MLX uses lazy evaluation because it has some nice features, some of which we describe below. Transforming Compute Graphs ^^^^^^^^^^^^^^^^^^^^^^^^^^^ Lazy evaluation lets us record a compute graph without actually doing any computations. This is useful for function transformations like :func:`grad` and :func:`vmap` and graph optimizations. By default, MLX does not compile and rerun compute graphs; they are all generated dynamically. Lazy evaluation is also what makes it possible to opt into compilation with :func:`compile` for better performance. Only Compute What You Use ^^^^^^^^^^^^^^^^^^^^^^^^^ In MLX you do not need to worry as much about computing outputs that are never used. For example: .. code-block:: python def fun(x): a = fun1(x) b = expensive_fun(a) return a, b y, _ = fun(x) Here, we never actually compute the output of ``expensive_fun``. Use this pattern with care though, as the graph of ``expensive_fun`` is still built, and that has some cost associated with it. Similarly, lazy evaluation can be beneficial for saving memory while keeping code simple. Say you have a very large model ``Model`` derived from :obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``. Typically, this will initialize all of the weights as ``float32``, but the initialization does not actually compute anything until you perform an :func:`eval`. If you update the model with ``float16`` weights, your maximum consumed memory will be half that required if eager computation was used instead. This pattern is simple to do in MLX thanks to lazy computation: .. code-block:: python model = Model() # no memory used yet model.load_weights("weights_fp16.safetensors") When to Evaluate ---------------- A common question is when to use :func:`eval`. The trade-off is between letting graphs get too large and not batching enough useful work. For example: ..
code-block:: python for _ in range(100): a = a + b mx.eval(a) b = b * 2 mx.eval(b) This is a bad idea because there is some fixed overhead with each graph evaluation. On the other hand, there is some slight overhead which grows with the compute graph size, so extremely large graphs (while computationally correct) can be costly. Luckily, a wide range of compute graph sizes work pretty well with MLX: anything from a few tens of operations to many thousands of operations per evaluation should be okay. Most numerical computations have an iterative outer loop (e.g. the iteration in stochastic gradient descent). A natural and usually efficient place to use :func:`eval` is at each iteration of this outer loop. Here is a concrete example: .. code-block:: python for batch in dataset: # Nothing has been evaluated yet loss, grad = value_and_grad_fn(model, batch) # Still nothing has been evaluated optimizer.update(model, grad) # Evaluate the loss and the new parameters which will # run the full gradient computation and optimizer update mx.eval(loss, model.parameters()) An important behavior to be aware of is when the graph will be implicitly evaluated. Anytime you ``print`` an array, convert it to an :obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`, the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX saving functions) will also evaluate the array. Calling :func:`array.item` on a scalar array will also evaluate it. In the example above, printing the loss (``print(loss)``) or adding the loss scalar to a list (``losses.append(loss.item())``) would cause a graph evaluation. If these lines are before ``mx.eval(loss, model.parameters())`` then this will be a partial evaluation, computing only the forward pass. Also, calling :func:`eval` on an array or set of arrays multiple times is perfectly fine. This is effectively a no-op. .. warning:: Using scalar arrays for control-flow will cause an evaluation. Here is an example: .. 
code-block:: python def fun(x): h, y = first_layer(x) if y > 0: # An evaluation is done here! z = second_layer_a(h) else: z = second_layer_b(h) return z Using arrays for control flow should be done with care. The above example works and can even be used with gradient transformations. However, this can be very inefficient if evaluations are done too frequently. ================================================ FILE: docs/src/usage/numpy.rst ================================================ .. _numpy: Conversion to NumPy and Other Frameworks ======================================== MLX array supports conversion to and from other frameworks with either: * The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_. * `DLPack <https://dmlc.github.io/dlpack/latest/>`_. Let's convert an array to NumPy and back. .. code-block:: python import mlx.core as mx import numpy as np a = mx.arange(3) b = np.array(a) # copy of a c = mx.array(b) # copy of b .. note:: Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``. Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.`` By default, NumPy copies data to a new array. This can be prevented by creating an array view: .. code-block:: python a = mx.arange(3) a_view = np.array(a, copy=False) print(a_view.flags.owndata) # False a_view[0] = 1 print(a[0].item()) # 1 .. note:: NumPy arrays with type ``float64`` will by default be converted to MLX arrays with type ``float32``. A NumPy array view is a normal NumPy array, except that it does not own its memory. This means writing to the view is reflected in the original array. While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients. Let's demonstrate this in an example: ..
code-block:: python def f(x): x_view = np.array(x, copy=False) x_view[:] *= x_view # modify memory without telling mx return x.sum() x = mx.array([3.0]) y, df = mx.value_and_grad(f)(x) print("f(x) = x² =", y.item()) # 9.0 print("f'(x) = 2x !=", df.item()) # 1.0 The function ``f`` indirectly modifies the array ``x`` through a memory view. However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``, representing the gradient of the sum operation alone. The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated. It's important to note that a similar issue arises during array conversion and copying. For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient, even though no in-place operations on MLX memory are executed. PyTorch ------- .. warning:: PyTorch Support for :obj:`memoryview` is experimental and can break for multi-dimensional arrays. Casting to NumPy first is advised for now. PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`. .. code-block:: python import mlx.core as mx import torch a = mx.arange(3) b = torch.tensor(memoryview(a)) c = mx.array(b) JAX --- JAX fully supports the buffer protocol. .. code-block:: python import mlx.core as mx import jax.numpy as jnp a = mx.arange(3) b = jnp.array(a) c = mx.array(b) TensorFlow ---------- TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`. .. code-block:: python import mlx.core as mx import tensorflow as tf a = mx.arange(3) b = tf.constant(memoryview(a)) c = mx.array(b) ================================================ FILE: docs/src/usage/quick_start.rst ================================================ Quick Start Guide ================= Basics ------ .. currentmodule:: mlx.core Import ``mlx.core`` and make an :class:`array`: .. 
code-block:: python >> import mlx.core as mx >> a = mx.array([1, 2, 3, 4]) >> a.shape [4] >> a.dtype int32 >> b = mx.array([1.0, 2.0, 3.0, 4.0]) >> b.dtype float32 Operations in MLX are lazy. The outputs of MLX operations are not computed until they are needed. To force an array to be evaluated use :func:`eval`. Arrays will automatically be evaluated in a few cases. For example, inspecting a scalar with :meth:`array.item`, printing an array, or converting an array from :class:`array` to :class:`numpy.ndarray` all automatically evaluate the array. .. code-block:: python >> c = a + b # c not yet evaluated >> mx.eval(c) # evaluates c >> c = a + b >> print(c) # Also evaluates c array([2, 4, 6, 8], dtype=float32) >> c = a + b >> import numpy as np >> np.array(c) # Also evaluates c array([2., 4., 6., 8.], dtype=float32) See the page on :ref:`Lazy Evaluation ` for more details. Function and Graph Transformations ---------------------------------- MLX has standard function transformations like :func:`grad` and :func:`vmap`. Transformations can be composed arbitrarily. For example ``grad(vmap(grad(fn)))`` (or any other composition) is allowed. .. code-block:: python >> x = mx.array(0.0) >> mx.sin(x) array(0, dtype=float32) >> mx.grad(mx.sin)(x) array(1, dtype=float32) >> mx.grad(mx.grad(mx.sin))(x) array(-0, dtype=float32) Other gradient transformations include :func:`vjp` for vector-Jacobian products and :func:`jvp` for Jacobian-vector products. Use :func:`value_and_grad` to efficiently compute both a function's output and gradient with respect to the function's input. ================================================ FILE: docs/src/usage/saving_and_loading.rst ================================================ .. _saving_and_loading: Saving and Loading Arrays ========================= .. currentmodule:: mlx.core MLX supports multiple array serialization formats. .. 
list-table:: Serialization Formats :widths: 20 8 25 25 :header-rows: 1 * - Format - Extension - Function - Notes * - NumPy - ``.npy`` - :func:`save` - Single arrays only * - NumPy archive - ``.npz`` - :func:`savez` and :func:`savez_compressed` - Multiple arrays * - Safetensors - ``.safetensors`` - :func:`save_safetensors` - Multiple arrays * - GGUF - ``.gguf`` - :func:`save_gguf` - Multiple arrays The :func:`load` function will load any of the supported serialization formats. It determines the format from the extensions. The output of :func:`load` depends on the format. Here's an example of saving a single array to a file: .. code-block:: shell >>> a = mx.array([1.0]) >>> mx.save("array", a) The array ``a`` will be saved in the file ``array.npy`` (notice the extension is automatically added). Including the extension is optional; if it is missing it will be added. You can load the array with: .. code-block:: shell >>> mx.load("array.npy") array([1], dtype=float32) Here's an example of saving several arrays to a single file: .. code-block:: shell >>> a = mx.array([1.0]) >>> b = mx.array([2.0]) >>> mx.savez("arrays", a, b=b) For compatibility with :func:`numpy.savez` the MLX :func:`savez` takes arrays as arguments. If the keywords are missing, then default names will be provided. This can be loaded with: .. code-block:: shell >>> mx.load("arrays.npz") {'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)} In this case :func:`load` returns a dictionary of names to arrays. The functions :func:`save_safetensors` and :func:`save_gguf` are similar to :func:`savez`, but they take as input a :obj:`dict` of string names to arrays: .. code-block:: shell >>> a = mx.array([1.0]) >>> b = mx.array([2.0]) >>> mx.save_safetensors("arrays", {"a": a, "b": b}) ================================================ FILE: docs/src/usage/unified_memory.rst ================================================ .. _unified_memory: Unified Memory ============== .. 
currentmodule:: mlx.core Apple silicon has a unified memory architecture. The CPU and GPU have direct access to the same memory pool. MLX is designed to take advantage of that. Concretely, when you make an array in MLX you don't have to specify its location: .. code-block:: python a = mx.random.normal((100,)) b = mx.random.normal((100,)) Both ``a`` and ``b`` live in unified memory. In MLX, rather than moving arrays to devices, you specify the device when you run the operation. Any device can perform any operation on ``a`` and ``b`` without needing to move them from one memory location to another. For example: .. code-block:: python mx.add(a, b, stream=mx.cpu) mx.add(a, b, stream=mx.gpu) In the above, both the CPU and the GPU will perform the same add operation. The operations can (and likely will) be run in parallel since there are no dependencies between them. See :ref:`using_streams` for more information on the semantics of streams in MLX. In the above ``add`` example, there are no dependencies between operations, so there is no possibility for race conditions. If there are dependencies, the MLX scheduler will automatically manage them. For example: .. code-block:: python c = mx.add(a, b, stream=mx.cpu) d = mx.add(a, c, stream=mx.gpu) In the above case, the second ``add`` runs on the GPU but it depends on the output of the first ``add`` which is running on the CPU. MLX will automatically insert a dependency between the two streams so that the second ``add`` only starts executing after the first is complete and ``c`` is available. A Simple Example ~~~~~~~~~~~~~~~~ Here is a more interesting (albeit slightly contrived) example of how unified memory can be helpful. Suppose we have the following computation: .. code-block:: python def fun(a, b, d1, d2): x = mx.matmul(a, b, stream=d1) for _ in range(500): b = mx.exp(b, stream=d2) return x, b which we want to run with the following arguments: ..
code-block:: python a = mx.random.uniform(shape=(4096, 512)) b = mx.random.uniform(shape=(512, 4)) The first ``matmul`` operation is a good fit for the GPU since it's more compute dense. The second sequence of operations is a better fit for the CPU, since those operations are very small and would probably be overhead bound on the GPU. If we time the computation fully on the GPU, we get 2.8 milliseconds. But if we run the computation with ``d1=mx.gpu`` and ``d2=mx.cpu``, then the time is only about 1.4 milliseconds, about twice as fast. These times were measured on an M1 Max. ================================================ FILE: docs/src/usage/using_streams.rst ================================================ .. _using_streams: Using Streams ============= .. currentmodule:: mlx.core Specifying the :obj:`Stream` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ All operations (including random number generation) take an optional keyword argument ``stream``. The ``stream`` kwarg specifies which :obj:`Stream` the operation should run on. If the stream is unspecified then the operation is run on the default stream of the default device: ``mx.default_stream(mx.default_device())``. The ``stream`` kwarg can also be a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is run on the default stream of the provided device ``mx.default_stream(my_device)``. ================================================ FILE: examples/cmake_project/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.27) project(example LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) # Comment out the following two commands if only the MLX C++ library is installed, and # set(MLX_ROOT "/path/to/mlx") directly if needed.
find_package( Python 3.9 COMPONENTS Interpreter Development.Module REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE MLX_ROOT) find_package(MLX CONFIG REQUIRED) add_executable(example example.cpp) target_link_libraries(example PRIVATE mlx) ================================================ FILE: examples/cmake_project/README.md ================================================ ## Build and Run Install MLX with Python: ```bash pip install mlx>=0.22 ``` Build the C++ example: ```bash cmake -B build -DCMAKE_BUILD_TYPE=Release cmake --build build ``` Run the C++ example: ``` ./build/example ``` which should output: ``` array([2, 4, 6], dtype=int32) ``` ================================================ FILE: examples/cmake_project/example.cpp ================================================ // Copyright © 2024 Apple Inc. #include #include "mlx/mlx.h" namespace mx = mlx::core; int main() { auto x = mx::array({1, 2, 3}); auto y = mx::array({1, 2, 3}); std::cout << x + y << std::endl; return 0; } ================================================ FILE: examples/cpp/CMakeLists.txt ================================================ function(build_example SRCFILE) get_filename_component(src_name ${SRCFILE} NAME_WE) set(target "${src_name}") add_executable(${target} ${SRCFILE}) target_link_libraries(${target} PRIVATE mlx) endfunction(build_example) build_example(tutorial.cpp) build_example(linear_regression.cpp) build_example(logistic_regression.cpp) build_example(metal_capture.cpp) build_example(distributed.cpp) ================================================ FILE: examples/cpp/distributed.cpp ================================================ // Copyright © 2024 Apple Inc. 
#include #include "mlx/mlx.h" namespace mx = mlx::core; int main() { if (!mx::distributed::is_available()) { std::cout << "No communication backend found" << std::endl; return 1; } auto global_group = mx::distributed::init(); std::cout << global_group.rank() << " / " << global_group.size() << std::endl; mx::array x = mx::ones({10}); mx::array out = mx::distributed::all_sum(x, global_group); std::cout << out << std::endl; } ================================================ FILE: examples/cpp/linear_regression.cpp ================================================ // Copyright © 2023 Apple Inc. #include #include #include #include "mlx/mlx.h" #include "timer.h" /** * An example of linear regression with MLX. */ namespace mx = mlx::core; int main() { int num_features = 100; int num_examples = 1'000; int num_iters = 10'000; float learning_rate = 0.01; // True parameters auto w_star = mx::random::normal({num_features}); // The input examples (design matrix) auto X = mx::random::normal({num_examples, num_features}); // Noisy labels auto eps = 1e-2 * mx::random::normal({num_examples}); auto y = mx::matmul(X, w_star) + eps; // Initialize random parameters mx::array w = 1e-2 * mx::random::normal({num_features}); auto loss_fn = [&](mx::array w) { auto yhat = mx::matmul(X, w); return (0.5f / num_examples) * mx::sum(mx::square(yhat - y)); }; auto grad_fn = mx::grad(loss_fn); auto tic = timer::time(); for (int it = 0; it < num_iters; ++it) { auto grads = grad_fn(w); w = w - learning_rate * grads; mx::eval(w); } auto toc = timer::time(); auto loss = loss_fn(w); auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item()); auto throughput = num_iters / timer::seconds(toc - tic); std::cout << "Loss " << loss << ", |w - w*| = " << error_norm << ", Throughput " << throughput << " (it/s)." << std::endl; } ================================================ FILE: examples/cpp/logistic_regression.cpp ================================================ // Copyright © 2023 Apple Inc. 
#include #include #include #include "mlx/mlx.h" #include "timer.h" /** * An example of logistic regression with MLX. */ namespace mx = mlx::core; int main() { int num_features = 100; int num_examples = 1'000; int num_iters = 10'000; float learning_rate = 0.1; // True parameters auto w_star = mx::random::normal({num_features}); // The input examples auto X = mx::random::normal({num_examples, num_features}); // Labels auto y = mx::matmul(X, w_star) > 0; // Initialize random parameters mx::array w = 1e-2 * mx::random::normal({num_features}); auto loss_fn = [&](mx::array w) { auto logits = mx::matmul(X, w); auto scale = (1.0f / num_examples); return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits); }; auto grad_fn = mx::grad(loss_fn); auto tic = timer::time(); for (int it = 0; it < num_iters; ++it) { auto grads = grad_fn(w); w = w - learning_rate * grads; mx::eval(w); } auto toc = timer::time(); auto loss = loss_fn(w); auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples; auto throughput = num_iters / timer::seconds(toc - tic); std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput " << throughput << " (it/s)." << std::endl; } ================================================ FILE: examples/cpp/metal_capture.cpp ================================================ // Copyright © 2024 Apple Inc. #include #include #include "mlx/mlx.h" namespace mx = mlx::core; int main() { // To use Metal debugging and profiling: // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON). // 2. Run with MTL_CAPTURE_ENABLED=1. mx::metal::start_capture("mlx_trace.gputrace"); // Start at index two because the default GPU and CPU streams have indices // zero and one, respectively. This naming matches the label assigned to each // stream's command queue. 
auto s2 = new_stream(mx::Device::gpu); auto s3 = new_stream(mx::Device::gpu); auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2); auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3); auto x = mx::add(a, a, s2); auto y = mx::add(b, b, s3); // The multiply will happen on the default stream. std::cout << mx::multiply(x, y) << std::endl; mx::metal::stop_capture(); } ================================================ FILE: examples/cpp/timer.h ================================================ // Copyright © 2023 Apple Inc. #pragma once #include namespace timer { using namespace std::chrono; template inline double seconds(duration x) { return duration_cast(x).count() / 1e9; } inline auto time() { return high_resolution_clock::now(); } } // namespace timer ================================================ FILE: examples/cpp/tutorial.cpp ================================================ // Copyright © 2023 Apple Inc. #include #include #include "mlx/mlx.h" namespace mx = mlx::core; void array_basics() { // Make a scalar array: mx::array x(1.0); // Get the value out of it: auto s = x.item(); assert(s == 1.0); // Scalars have a size of 1: size_t size = x.size(); assert(size == 1); // Scalars have 0 dimensions: int ndim = x.ndim(); assert(ndim == 0); // The shape should be an empty vector: auto shape = x.shape(); assert(shape.empty()); // The datatype should be float32: auto dtype = x.dtype(); assert(dtype == mx::float32); // Specify the dtype when constructing the array: x = mx::array(1, mx::int32); assert(x.dtype() == mx::int32); x.item(); // OK // x.item(); // Undefined! // Make a multidimensional array: x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); // mlx is row-major by default so the first row of this array // is [1.0, 2.0] and the second row is [3.0, 4.0] // Make an array of shape {2, 2} filled with ones: auto y = mx::ones({2, 2}); // Pointwise add x and y: auto z = mx::add(x, y); // Same thing: z = x + y; // mlx is lazy by default. 
At this point `z` only // has a shape and a type but no actual data: assert(z.dtype() == mx::float32); assert(z.shape(0) == 2); assert(z.shape(1) == 2); // To actually run the computation you must evaluate `z`. // Under the hood, mlx records operations in a graph. // The variable `z` is a node in the graph which points to its operation // and inputs. When `eval` is called on an array (or arrays), the array and // all of its dependencies are recursively evaluated to produce the result. // Once an array is evaluated, it has data and is detached from its inputs. mx::eval(z); // Of course the array can still be an input to other operations. You can // even call eval on the array again, this will just be a no-op: mx::eval(z); // no-op // Some functions or methods on arrays implicitly evaluate them. For example // accessing a value in an array or printing the array implicitly evaluate it: z = mx::ones({1}); z.item(); // implicit evaluation z = mx::ones({2, 2}); std::cout << z << std::endl; // implicit evaluation } void automatic_differentiation() { auto fn = [](mx::array x) { return mx::square(x); }; // Computing the derivative function of a function auto grad_fn = mx::grad(fn); // Call grad_fn on the input to get the derivative auto x = mx::array(1.5); auto dfdx = grad_fn(x); // dfdx is 2 * x // Get the second derivative by composing grad with grad auto d2fdx2 = mx::grad(mx::grad(fn))(x); // d2fdx2 is 2 } int main() { array_basics(); automatic_differentiation(); } ================================================ FILE: examples/export/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.27) project(import_mlx LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package( Python 3.9 COMPONENTS Interpreter Development.Module REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE MLX_ROOT) find_package(MLX CONFIG REQUIRED) 
add_executable(eval_mlp eval_mlp.cpp) target_link_libraries(eval_mlp PRIVATE mlx) add_executable(train_mlp train_mlp.cpp) target_link_libraries(train_mlp PRIVATE mlx) ================================================ FILE: examples/export/README.md ================================================ ## Setup Install MLX: ```bash pip install mlx>=0.22 ``` Build the C++ examples: ```bash cmake -B build -DCMAKE_BUILD_TYPE=Release cmake --build build ``` ## Run ### Eval MLP Run the Python script to export the eval function: ```bash python eval_mlp.py ``` Then run the C++ program to import and run the function: ``` ./build/eval_mlp ``` The Python and C++ programs should output the same result. ### Train MLP Run the Python script to export the model initialization and training functions: ```bash python train_mlp.py ``` Then run the C++ program to import and run the functions: ``` ./build/train_mlp ``` The Python and C++ programs should output the same results. ================================================ FILE: examples/export/eval_mlp.cpp ================================================ // Copyright © 2024 Apple Inc. #include #include namespace mx = mlx::core; int main() { int batch_size = 8; int input_dim = 32; // Make the input mx::random::seed(42); auto example_x = mx::random::uniform({batch_size, input_dim}); // Import the function auto forward = mx::import_function("eval_mlp.mlxfn"); // Call the imported function auto out = forward({example_x})[0]; std::cout << out << std::endl; return 0; } ================================================ FILE: examples/export/eval_mlp.py ================================================ # Copyright © 2024 Apple Inc. 
import mlx.core as mx import mlx.nn as nn import mlx.utils class MLP(nn.Module): """A simple MLP.""" def __init__( self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int ): super().__init__() layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] self.layers = [ nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) ] def __call__(self, x): for l in self.layers[:-1]: x = nn.relu(l(x)) return self.layers[-1](x) if __name__ == "__main__": batch_size = 8 input_dim = 32 output_dim = 10 # Load the model mx.random.seed(0) # Seed for params model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim) mx.eval(model) # Note, the model parameters are saved in the export function def forward(x): return model(x) mx.random.seed(42) # Seed for input example_x = mx.random.uniform(shape=(batch_size, input_dim)) mx.export_function("eval_mlp.mlxfn", forward, example_x) # Import in Python imported_forward = mx.import_function("eval_mlp.mlxfn") expected = forward(example_x) (out,) = imported_forward(example_x) assert mx.allclose(expected, out) print(out) ================================================ FILE: examples/export/train_mlp.cpp ================================================ // Copyright © 2024 Apple Inc. 
#include #include namespace mx = mlx::core; int main() { int batch_size = 8; int input_dim = 32; int output_dim = 10; auto state = mx::import_function("init_mlp.mlxfn")({}); // Make the input mx::random::seed(42); auto example_X = mx::random::normal({batch_size, input_dim}); auto example_y = mx::random::randint(0, output_dim, {batch_size}); // Import the function auto step = mx::import_function("train_mlp.mlxfn"); // Call the imported function for (int it = 0; it < 100; ++it) { state.insert(state.end(), {example_X, example_y}); state = step(state); eval(state); auto loss = state.back(); state.pop_back(); if (it % 10 == 0) { std::cout << "Loss " << loss.item() << std::endl; } } return 0; } ================================================ FILE: examples/export/train_mlp.py ================================================ # Copyright © 2024 Apple Inc. import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim import mlx.utils class MLP(nn.Module): """A simple MLP.""" def __init__( self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int ): super().__init__() layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] self.layers = [ nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) ] def __call__(self, x): for l in self.layers[:-1]: x = nn.relu(l(x)) return self.layers[-1](x) if __name__ == "__main__": batch_size = 8 input_dim = 32 output_dim = 10 def init(): # Seed for the parameter initialization mx.random.seed(0) model = MLP( num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim ) optimizer = optim.SGD(learning_rate=1e-1) optimizer.init(model.parameters()) state = [model.parameters(), optimizer.state] tree_structure, state = zip(*mlx.utils.tree_flatten(state)) return model, optimizer, tree_structure, state # Export the model parameter initialization model, optimizer, tree_structure, state = init() mx.eval(state) mx.export_function("init_mlp.mlxfn", lambda: init()[-1]) def 
loss_fn(params, X, y): model.update(params) return nn.losses.cross_entropy(model(X), y, reduction="mean") def step(*inputs): *state, X, y = inputs params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state))) optimizer.state = opt_state loss, grads = mx.value_and_grad(loss_fn)(params, X, y) params = optimizer.apply_gradients(grads, params) _, state = zip(*mlx.utils.tree_flatten([params, optimizer.state])) return *state, loss # Make some random data mx.random.seed(42) example_X = mx.random.normal(shape=(batch_size, input_dim)) example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,)) mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y) # Export one step of SGD imported_step = mx.import_function("train_mlp.mlxfn") for it in range(100): *state, loss = imported_step(*state, example_X, example_y) if it % 10 == 0: print(f"Loss {loss.item():.6}") ================================================ FILE: examples/extensions/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.27) project(_ext LANGUAGES CXX) # ----------------------------- Setup ----------------------------- set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON) # ----------------------------- Dependencies ----------------------------- find_package( Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT) find_package(nanobind CONFIG REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE MLX_ROOT) find_package(MLX CONFIG REQUIRED) # ----------------------------- Extensions ----------------------------- # Add library add_library(mlx_ext) # Add sources target_sources(mlx_ext PUBLIC 
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp) # Add include headers target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}) # Link to mlx target_link_libraries(mlx_ext PUBLIC mlx) # ----------------------------- Metal ----------------------------- # Build metallib if(MLX_BUILD_METAL) mlx_build_metallib( TARGET mlx_ext_metallib TITLE mlx_ext SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS} OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) add_dependencies(mlx_ext mlx_ext_metallib) endif() # ----------------------------- Python Bindings ----------------------------- nanobind_add_module( _ext NB_STATIC STABLE_ABI LTO NOMINSIZE NB_DOMAIN mlx ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp) target_link_libraries(_ext PRIVATE mlx_ext) if(BUILD_SHARED_LIBS) target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path) endif() ================================================ FILE: examples/extensions/README.md ================================================ ## Build ``` pip install -e . ``` For faster builds during development, you can also pre-install the requirements: ``` pip install -r requirements.txt ``` And then run: ``` python setup.py build_ext -j8 --inplace ``` ## Test ``` python test.py ``` ================================================ FILE: examples/extensions/axpby/axpby.cpp ================================================ // Copyright © 2023-2025 Apple Inc. #include #include #include #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/utils.h" #include "axpby/axpby.h" #ifdef _METAL_ #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" #endif namespace my_ext { // A helper function to find the location of the current binary on disk. // The Metal library ("mlx_ext.mtllib"), should be in the same directory. 
std::string current_binary_dir() { static std::string binary_dir = []() { Dl_info info; if (!dladdr(reinterpret_cast(¤t_binary_dir), &info)) { throw std::runtime_error("Unable to get current binary dir."); } return std::filesystem::path(info.dli_fname).parent_path().string(); }(); return binary_dir; } /////////////////////////////////////////////////////////////////////////////// // Operation Implementation /////////////////////////////////////////////////////////////////////////////// /** * Scale and sum two vectors element-wise * z = alpha * x + beta * y * * Follow numpy style broadcasting between x and y * Inputs are upcasted to floats if needed **/ mx::array axpby( const mx::array& x, // Input mx::array x const mx::array& y, // Input mx::array y const float alpha, // Scaling factor for x const float beta, // Scaling factor for y mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation ) { // Promote dtypes between x and y as needed auto promoted_dtype = promote_types(x.dtype(), y.dtype()); // Upcast to float32 for non-floating point inputs x and y auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32) ? 
promoted_dtype : promote_types(promoted_dtype, mx::float32); // Cast x and y up to the determined dtype (on the same stream s) auto x_casted = mx::astype(x, out_dtype, s); auto y_casted = mx::astype(y, out_dtype, s); // Broadcast the shapes of x and y (on the same stream s) auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s); auto out_shape = broadcasted_inputs[0].shape(); // Construct the array as the output of the Axpby primitive // with the broadcasted and upcasted arrays as inputs return mx::array( /* const mx::Shape& shape = */ out_shape, /* mx::Dtype dtype = */ out_dtype, /* std::shared_ptr primitive = */ std::make_shared(to_stream(s), alpha, beta), /* const std::vector& inputs = */ broadcasted_inputs); } /////////////////////////////////////////////////////////////////////////////// // Primitive Common Backend Implementation /////////////////////////////////////////////////////////////////////////////// template void axpby_impl( const mx::array& x, const mx::array& y, mx::array& out, float alpha_, float beta_, mx::Stream stream) { out.set_data(mx::allocator::malloc(out.nbytes())); // Get the CPU command encoder and register input and output arrays auto& encoder = mx::cpu::get_command_encoder(stream); encoder.set_input_array(x); encoder.set_input_array(y); encoder.set_output_array(out); // Launch the CPU kernel encoder.dispatch([x_ptr = x.data(), y_ptr = y.data(), out_ptr = out.data(), size = out.size(), shape = out.shape(), x_strides = x.strides(), y_strides = y.strides(), alpha_, beta_]() { // Cast alpha and beta to the relevant types T alpha = static_cast(alpha_); T beta = static_cast(beta_); // Do the element-wise operation for each output for (size_t out_idx = 0; out_idx < size; out_idx++) { // Map linear indices to offsets in x and y auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides); auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides); // We allocate the output to be contiguous and regularly strided // (defaults to row 
major) and hence it doesn't need additional mapping out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; } }); } void Axpby::eval_cpu( const std::vector& inputs, std::vector& outputs) { auto& x = inputs[0]; auto& y = inputs[1]; auto& out = outputs[0]; // Dispatch to the correct dtype if (out.dtype() == mx::float32) { return axpby_impl(x, y, out, alpha_, beta_, stream()); } else if (out.dtype() == mx::float16) { return axpby_impl(x, y, out, alpha_, beta_, stream()); } else if (out.dtype() == mx::bfloat16) { return axpby_impl(x, y, out, alpha_, beta_, stream()); } else if (out.dtype() == mx::complex64) { return axpby_impl(x, y, out, alpha_, beta_, stream()); } else { throw std::runtime_error( "Axpby is only supported for floating point types."); } } /////////////////////////////////////////////////////////////////////////////// // Primitive Metal Backend Implementation /////////////////////////////////////////////////////////////////////////////// #ifdef _METAL_ /** Evaluate primitive on GPU */ void Axpby::eval_gpu( const std::vector& inputs, std::vector& outputs) { // Prepare inputs auto& x = inputs[0]; auto& y = inputs[1]; auto& out = outputs[0]; // Each primitive carries the stream it should execute on // and each stream carries its device identifiers auto& s = stream(); // We get the needed metal device using the stream auto& d = mx::metal::device(s.device); // Prepare to specialize based on contiguity bool contiguous_kernel = (x.flags().row_contiguous && y.flags().row_contiguous) || (x.flags().col_contiguous && y.flags().col_contiguous); // Allocate output memory with strides based on specialization if (contiguous_kernel) { out.set_data( mx::allocator::malloc(x.data_size() * out.itemsize()), x.data_size(), x.strides(), x.flags()); } else { out.set_data(mx::allocator::malloc(out.nbytes())); } // Resolve name of kernel (corresponds to axpby.metal) std::string kname = "axpby_"; kname += (contiguous_kernel ? 
"contiguous_" : "general_"); kname += type_to_name(out); // Load the metal library auto lib = d.get_library("mlx_ext", current_binary_dir()); // Make a kernel from this metal library auto kernel = d.get_kernel(kname, lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); // Kernel parameters are registered with buffer indices corresponding to // those in the kernel declaration at axpby.metal int ndim = out.ndim(); size_t nelem = out.size(); // Encode input arrays to kernel compute_encoder.set_input_array(x, 0); compute_encoder.set_input_array(y, 1); // Encode output arrays to kernel compute_encoder.set_output_array(out, 2); // Encode alpha and beta compute_encoder.set_bytes(alpha_, 3); compute_encoder.set_bytes(beta_, 4); // Encode shape, strides and ndim if needed if (!contiguous_kernel) { compute_encoder.set_vector_bytes(x.shape(), 5); compute_encoder.set_vector_bytes(x.strides(), 6); compute_encoder.set_vector_bytes(y.strides(), 7); compute_encoder.set_bytes(ndim, 8); } // We launch 1 thread for each input and make sure that the number of // threads in any given threadgroup is not higher than the max allowed size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup()); // Fix the 3D size of each threadgroup (in terms of threads) MTL::Size group_dims = MTL::Size(tgp_size, 1, 1); // Fix the 3D size of the launch grid (in terms of threads) MTL::Size grid_dims = MTL::Size(nelem, 1, 1); // Launch the grid with the given number of threads divided among // the given threadgroups compute_encoder.dispatch_threads(grid_dims, group_dims); } #else // Metal is not available /** Fail evaluation on GPU */ void Axpby::eval_gpu( const std::vector& inputs, std::vector& out) { throw std::runtime_error("Axpby has no GPU implementation."); } #endif /////////////////////////////////////////////////////////////////////////////// // Primitive Transforms 
/////////////////////////////////////////////////////////////////////////////// /** The Jacobian-vector product. */ std::vector Axpby::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { // Forward mode diff that pushes along the tangents // The jvp transform on the primitive can built with ops // that are scheduled on the same stream as the primitive // If argnums = {0}, we only push along x in which case the // jvp is just the tangent scaled by alpha // Similarly, if argnums = {1}, the jvp is just the tangent // scaled by beta if (argnums.size() > 1) { auto scale = argnums[0] == 0 ? alpha_ : beta_; auto scale_arr = mx::array(scale, tangents[0].dtype()); return {mx::multiply(scale_arr, tangents[0], stream())}; } // If, argnums = {0, 1}, we take contributions from both // which gives us jvp = tangent_x * alpha + tangent_y * beta else { return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; } } /** The vector-Jacobian product. */ std::vector Axpby::vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector&) { // Reverse mode diff std::vector vjps; for (auto arg : argnums) { auto scale = arg == 0 ? alpha_ : beta_; auto scale_arr = mx::array(scale, cotangents[0].dtype()); vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream())); } return vjps; } /** Vectorize primitive along given axis */ std::pair, std::vector> Axpby::vmap( const std::vector& inputs, const std::vector& axes) { throw std::runtime_error("Axpby has no vmap implementation."); } /** Equivalence check **/ bool Axpby::is_equivalent(const Primitive& other) const { const Axpby& r_other = static_cast(other); return alpha_ == r_other.alpha_ && beta_ == r_other.beta_; } } // namespace my_ext ================================================ FILE: examples/extensions/axpby/axpby.h ================================================ // Copyright © 2023-2025 Apple Inc. 
#pragma once #include "mlx/ops.h" #include "mlx/primitives.h" namespace mx = mlx::core; namespace my_ext { /////////////////////////////////////////////////////////////////////////////// // Operation /////////////////////////////////////////////////////////////////////////////// /** * Scale and sum two vectors element-wise * z = alpha * x + beta * y * * Follow numpy style broadcasting between x and y * Inputs are upcasted to floats if needed **/ mx::array axpby( const mx::array& x, // Input array x const mx::array& y, // Input array y const float alpha, // Scaling factor for x const float beta, // Scaling factor for y mx::StreamOrDevice s = {} // Stream on which to schedule the operation ); /////////////////////////////////////////////////////////////////////////////// // Primitive /////////////////////////////////////////////////////////////////////////////// class Axpby : public mx::Primitive { public: explicit Axpby(mx::Stream stream, float alpha, float beta) : mx::Primitive(stream), alpha_(alpha), beta_(beta) {}; /** * A primitive must know how to evaluate itself on the CPU/GPU * for the given inputs and populate the output array. * * To avoid unnecessary allocations, the evaluation function * is responsible for allocating space for the array. */ void eval_cpu( const std::vector& inputs, std::vector& outputs) override; void eval_gpu( const std::vector& inputs, std::vector& outputs) override; /** The Jacobian-vector product. */ std::vector jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) override; /** The vector-Jacobian product. */ std::vector vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; /** * The primitive must know how to vectorize itself across * the given axes. The output is a pair containing the array * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. 
*/ std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; /** The name of primitive. */ const char* name() const override { return "Axpby"; } /** Equivalence check **/ bool is_equivalent(const mx::Primitive& other) const override; private: float alpha_; float beta_; }; } // namespace my_ext ================================================ FILE: examples/extensions/axpby/axpby.metal ================================================ // Copyright © 2023-2025 Apple Inc. #include #include "mlx/backend/metal/kernels/utils.h" template [[kernel]] void axpby_general( device const T* x [[buffer(0)]], device const T* y [[buffer(1)]], device T* out [[buffer(2)]], constant const float& alpha [[buffer(3)]], constant const float& beta [[buffer(4)]], constant const int* shape [[buffer(5)]], constant const int64_t* x_strides [[buffer(6)]], constant const int64_t* y_strides [[buffer(7)]], constant const int& ndim [[buffer(8)]], uint index [[thread_position_in_grid]]) { auto x_offset = elem_to_loc(index, shape, x_strides, ndim); auto y_offset = elem_to_loc(index, shape, y_strides, ndim); out[index] = static_cast(alpha) * x[x_offset] + static_cast(beta) * y[y_offset]; } template [[kernel]] void axpby_contiguous( device const T* x [[buffer(0)]], device const T* y [[buffer(1)]], device T* out [[buffer(2)]], constant const float& alpha [[buffer(3)]], constant const float& beta [[buffer(4)]], uint index [[thread_position_in_grid]]) { out[index] = static_cast(alpha) * x[index] + static_cast(beta) * y[index]; } // clang-format off #define instantiate_axpby(type_name, type) \ instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \ instantiate_kernel( \ "axpby_contiguous_" #type_name, axpby_contiguous, type) instantiate_axpby(float32, float); instantiate_axpby(float16, half); instantiate_axpby(bfloat16, bfloat16_t); instantiate_axpby(complex64, complex64_t); // clang-format on ================================================ FILE: 
examples/extensions/bindings.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include #include #include "axpby/axpby.h" namespace nb = nanobind; using namespace nb::literals; NB_MODULE(_ext, m) { m.doc() = "Sample extension for MLX"; m.def( "axpby", &my_ext::axpby, "x"_a, "y"_a, "alpha"_a, "beta"_a, nb::kw_only(), "stream"_a = nb::none(), R"( Scale and sum two vectors element-wise ``z = alpha * x + beta * y`` Follows numpy style broadcasting between ``x`` and ``y`` Inputs are upcasted to floats if needed Args: x (array): Input array. y (array): Input array. alpha (float): Scaling factor for ``x``. beta (float): Scaling factor for ``y``. Returns: array: ``alpha * x + beta * y`` )"); } ================================================ FILE: examples/extensions/mlx_sample_extensions/__init__.py ================================================ # Copyright © 2023 Apple Inc. import mlx.core as mx from ._ext import axpby ================================================ FILE: examples/extensions/pyproject.toml ================================================ [build-system] requires = [ "setuptools>=42", "cmake>=3.25", "mlx>=0.18.0", "nanobind==2.10.2", ] build-backend = "setuptools.build_meta" ================================================ FILE: examples/extensions/requirements.txt ================================================ setuptools>=42 cmake>=3.25 mlx>=0.21.0 nanobind==2.10.2 ================================================ FILE: examples/extensions/setup.py ================================================ # Copyright © 2023-2024 Apple Inc. 
from setuptools import setup from mlx import extension if __name__ == "__main__": setup( name="mlx_sample_extensions", version="0.0.0", description="Sample C++ and Metal extensions for MLX primitives.", ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")], cmdclass={"build_ext": extension.CMakeBuild}, packages=["mlx_sample_extensions"], package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]}, zip_safe=False, python_requires=">=3.8", ) ================================================ FILE: examples/extensions/test.py ================================================ import mlx.core as mx from mlx_sample_extensions import axpby a = mx.ones((3, 4)) b = mx.ones((3, 4)) c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu) c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu) print(f"c shape: {c_cpu.shape}") print(f"c dtype: {c_cpu.dtype}") print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}") print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}") ================================================ FILE: examples/python/linear_regression.py ================================================ # Copyright © 2023 Apple Inc. 
import time import mlx.core as mx num_features = 100 num_examples = 1_000 num_iters = 10_000 lr = 0.01 # True parameters w_star = mx.random.normal((num_features,)) # Input examples (design matrix) X = mx.random.normal((num_examples, num_features)) # Noisy labels eps = 1e-2 * mx.random.normal((num_examples,)) y = X @ w_star + eps # Initialize random parameters w = 1e-2 * mx.random.normal((num_features,)) def loss_fn(w): return 0.5 * mx.mean(mx.square(X @ w - y)) grad_fn = mx.grad(loss_fn) tic = time.perf_counter() for _ in range(num_iters): grad = grad_fn(w) w = w - lr * grad mx.eval(w) toc = time.perf_counter() loss = loss_fn(w) error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5 throughput = num_iters / (toc - tic) print( f"Loss {loss.item():.5f}, L2 distance: |w-w*| = {error_norm:.5f}, " f"Throughput {throughput:.5f} (it/s)" ) ================================================ FILE: examples/python/logistic_regression.py ================================================ # Copyright © 2023 Apple Inc. 
import time import mlx.core as mx num_features = 100 num_examples = 1_000 num_iters = 10_000 lr = 0.1 # True parameters w_star = mx.random.normal((num_features,)) # Input examples X = mx.random.normal((num_examples, num_features)) # Labels y = (X @ w_star) > 0 # Initialize random parameters w = 1e-2 * mx.random.normal((num_features,)) def loss_fn(w): logits = X @ w return mx.mean(mx.logaddexp(0.0, logits) - y * logits) grad_fn = mx.grad(loss_fn) tic = time.perf_counter() for _ in range(num_iters): grad = grad_fn(w) w = w - lr * grad mx.eval(w) toc = time.perf_counter() loss = loss_fn(w) final_preds = (X @ w) > 0 acc = mx.mean(final_preds == y) throughput = num_iters / (toc - tic) print( f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} " f"Throughput {throughput:.5f} (it/s)" ) ================================================ FILE: examples/python/qqmm.py ================================================ from itertools import product import mlx.core as mx # In mxfp8 mode, the results do not match exactly: # fewer than 1% of output elements differ. # This does not appear to be a systematic error. # The error can exceed 1 ULP for very small values, # and is always below 1 ULP for larger values. # For nvfp4, the results match exactly. # therefore I suspect that the discrepancy comes from # the mxfp8 matmul implementation in cuBLASLt.. 
def ulp_bf16_at(x): ax = mx.abs(x) min_normal = mx.array(2.0**-126) ax = mx.where(ax < min_normal, min_normal, ax) e = mx.floor(mx.log2(ax)) return mx.power(2.0, e - 7.0) def test_qqmm(): key = mx.random.key(0) k1, k2 = mx.random.split(key) dtypes = [mx.bfloat16, mx.float32, mx.float16] tests = ( (16, "nvfp4", 4), (32, "mxfp8", 8), ) shapes = ( [64, 65, 33, 128, 256, 1024, 1024 * 8], # M [64, 128, 256, 1024, 1024 * 8], # N [64, 128, 256, 1024, 1024 * 8], # K ) for group_size, mode, bits in tests: for M, N, K in product(*shapes): for dtype in dtypes: x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype) w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype) w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) w_dq = mx.dequantize( w_q, scales_w, group_size=group_size, bits=bits, mode=mode, dtype=dtype, ) y_q = mx.qqmm( x, w_q, scales_w, group_size=group_size, bits=bits, mode=mode, ) x_q, scales_x = mx.quantize( x, group_size=group_size, bits=bits, mode=mode ) x_dq = mx.dequantize( x_q, scales_x, group_size=group_size, bits=bits, mode=mode, dtype=dtype, ) y_hat = mx.matmul(x_dq, mx.transpose(w_dq)) ulp = ulp_bf16_at(y_hat) error = (y_q - y_hat).abs() if not (mx.logical_or(error < 1e-3, error <= ulp).all()): raise AssertionError( f"qqmm test failed for shape {(M, N, K)}, " f"group_size={group_size}, bits={bits}, " f"mode={mode}, dtype={dtype}" ) def test_qqmm_vjp(): key = mx.random.key(0) k1, k2 = mx.random.split(key) M = 64 N = 1024 K = 512 tests = ( (16, "nvfp4", 4), (32, "mxfp8", 8), ) x = mx.random.normal(shape=(M, K), key=k1) c = mx.ones(shape=(M, N)) for group_size, mode, bits in tests: w = mx.random.normal(shape=(N, K), key=k2) def fn(x): return mx.qqmm(x, w, group_size=group_size, bits=bits, mode=mode) _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,)) w_tq, scales_wt = mx.quantize( mx.transpose(w), group_size=group_size, bits=bits, mode=mode ) expected_out = mx.qqmm( c, w_tq, scales_wt, group_size=group_size, bits=bits, mode=mode ) ulp = 
ulp_bf16_at(expected_out) error = (vjp_out[0] - expected_out).abs() if not (mx.logical_or(error < 1e-3, error <= ulp).all()): raise AssertionError( f"qqmm vjp test failed for shape {(M, N, K)}, " f"group_size={group_size}, bits={bits}, mode={mode}" ) if __name__ == "__main__": test_qqmm() test_qqmm_vjp() ================================================ FILE: mlx/3rdparty/.clang-format ================================================ DisableFormat: true SortIncludes: Never ================================================ FILE: mlx/3rdparty/pocketfft.h ================================================ /* This file is part of pocketfft. Copyright (C) 2010-2022 Max-Planck-Society Copyright (C) 2019-2020 Peter Bell For the odd-sized DCT-IV transforms: Copyright (C) 2003, 2007-14 Matteo Frigo Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology Authors: Martin Reinecke, Peter Bell All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ #ifndef POCKETFFT_HDRONLY_H #define POCKETFFT_HDRONLY_H #ifndef __cplusplus #error This file is C++ and requires a C++ compiler. #endif #if !(__cplusplus >= 201103L || _MSVC_LANG+0L >= 201103L) #error This file requires at least C++11 support. #endif #ifndef POCKETFFT_CACHE_SIZE #define POCKETFFT_CACHE_SIZE 0 #endif #include #include #include #include #include #include #include #if POCKETFFT_CACHE_SIZE!=0 #include #include #endif #ifndef POCKETFFT_NO_MULTITHREADING #include #include #include #include #include #include #include #ifdef POCKETFFT_PTHREADS # include #endif #endif #if defined(__GNUC__) #define POCKETFFT_NOINLINE __attribute__((noinline)) #define POCKETFFT_RESTRICT __restrict__ #elif defined(_MSC_VER) #define POCKETFFT_NOINLINE __declspec(noinline) #define POCKETFFT_RESTRICT __restrict #else #define POCKETFFT_NOINLINE #define POCKETFFT_RESTRICT #endif namespace pocketfft { namespace detail { using std::size_t; using std::ptrdiff_t; // Always use std:: for functions template T cos(T) = delete; template T sin(T) = delete; template T sqrt(T) = delete; using shape_t = std::vector; using stride_t = std::vector; constexpr bool FORWARD = true, BACKWARD = false; // only enable vector support for gcc>=5.0 and clang>=5.0 #ifndef POCKETFFT_NO_VECTORS #define POCKETFFT_NO_VECTORS #if defined(__INTEL_COMPILER) // do nothing. This is necessary because this compiler also sets __GNUC__. 
#elif defined(__clang__) // AppleClang has their own version numbering #ifdef __apple_build_version__ # if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1) # undef POCKETFFT_NO_VECTORS # endif #elif __clang_major__ >= 5 # undef POCKETFFT_NO_VECTORS #endif #elif defined(__GNUC__) #if __GNUC__>=5 #undef POCKETFFT_NO_VECTORS #endif #endif #endif template struct VLEN { static constexpr size_t val=1; }; #ifndef POCKETFFT_NO_VECTORS #if (defined(__AVX512F__)) template<> struct VLEN { static constexpr size_t val=16; }; template<> struct VLEN { static constexpr size_t val=8; }; #elif (defined(__AVX__)) template<> struct VLEN { static constexpr size_t val=8; }; template<> struct VLEN { static constexpr size_t val=4; }; #elif (defined(__SSE2__)) template<> struct VLEN { static constexpr size_t val=4; }; template<> struct VLEN { static constexpr size_t val=2; }; #elif (defined(__VSX__)) template<> struct VLEN { static constexpr size_t val=4; }; template<> struct VLEN { static constexpr size_t val=2; }; #elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) template<> struct VLEN { static constexpr size_t val=4; }; template<> struct VLEN { static constexpr size_t val=2; }; #else #define POCKETFFT_NO_VECTORS #endif #endif // the __MINGW32__ part in the conditional below works around the problem that // the standard C++ library on Windows does not provide aligned_alloc() even // though the MinGW compiler and MSVC may advertise C++17 compliance. 
#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER)) inline void *aligned_alloc(size_t align, size_t size) { // aligned_alloc() requires that the requested size is a multiple of "align" void *ptr = ::aligned_alloc(align,(size+align-1)&(~(align-1))); if (!ptr) throw std::bad_alloc(); return ptr; } inline void aligned_dealloc(void *ptr) { free(ptr); } #else // portable emulation inline void *aligned_alloc(size_t align, size_t size) { align = std::max(align, alignof(max_align_t)); void *ptr = malloc(size+align); if (!ptr) throw std::bad_alloc(); void *res = reinterpret_cast ((reinterpret_cast(ptr) & ~(uintptr_t(align-1))) + uintptr_t(align)); (reinterpret_cast(res))[-1] = ptr; return res; } inline void aligned_dealloc(void *ptr) { if (ptr) free((reinterpret_cast(ptr))[-1]); } #endif template class arr { private: T *p; size_t sz; #if defined(POCKETFFT_NO_VECTORS) static T *ralloc(size_t num) { if (num==0) return nullptr; void *res = malloc(num*sizeof(T)); if (!res) throw std::bad_alloc(); return reinterpret_cast(res); } static void dealloc(T *ptr) { free(ptr); } #else static T *ralloc(size_t num) { if (num==0) return nullptr; void *ptr = aligned_alloc(64, num*sizeof(T)); return static_cast(ptr); } static void dealloc(T *ptr) { aligned_dealloc(ptr); } #endif public: arr() : p(0), sz(0) {} arr(size_t n) : p(ralloc(n)), sz(n) {} arr(arr &&other) : p(other.p), sz(other.sz) { other.p=nullptr; other.sz=0; } ~arr() { dealloc(p); } void resize(size_t n) { if (n==sz) return; dealloc(p); p = ralloc(n); sz = n; } T &operator[](size_t idx) { return p[idx]; } const T &operator[](size_t idx) const { return p[idx]; } T *data() { return p; } const T *data() const { return p; } size_t size() const { return sz; } }; template struct cmplx { T r, i; cmplx() {} cmplx(T r_, T i_) : r(r_), i(i_) {} void Set(T r_, T i_) { r=r_; i=i_; } void Set(T r_) { r=r_; i=T(0); } cmplx &operator+= (const cmplx &other) { r+=other.r; i+=other.i; return *this; } templatecmplx 
&operator*= (T2 other) { r*=other; i*=other; return *this; } templatecmplx &operator*= (const cmplx &other) { T tmp = r*other.r - i*other.i; i = r*other.i + i*other.r; r = tmp; return *this; } templatecmplx &operator+= (const cmplx &other) { r+=other.r; i+=other.i; return *this; } templatecmplx &operator-= (const cmplx &other) { r-=other.r; i-=other.i; return *this; } template auto operator* (const T2 &other) const -> cmplx { return {r*other, i*other}; } template auto operator+ (const cmplx &other) const -> cmplx { return {r+other.r, i+other.i}; } template auto operator- (const cmplx &other) const -> cmplx { return {r-other.r, i-other.i}; } template auto operator* (const cmplx &other) const -> cmplx { return {r*other.r-i*other.i, r*other.i + i*other.r}; } template auto special_mul (const cmplx &other) const -> cmplx { using Tres = cmplx; return fwd ? Tres(r*other.r+i*other.i, i*other.r-r*other.i) : Tres(r*other.r-i*other.i, r*other.i+i*other.r); } }; template inline void PM(T &a, T &b, T c, T d) { a=c+d; b=c-d; } template inline void PMINPLACE(T &a, T &b) { T t = a; a+=b; b=t-b; } template inline void MPINPLACE(T &a, T &b) { T t = a; a-=b; b=t+b; } template cmplx conj(const cmplx &a) { return {a.r, -a.i}; } template void special_mul (const cmplx &v1, const cmplx &v2, cmplx &res) { res = fwd ? cmplx(v1.r*v2.r+v1.i*v2.i, v1.i*v2.r-v1.r*v2.i) : cmplx(v1.r*v2.r-v1.i*v2.i, v1.r*v2.i+v1.i*v2.r); } template void ROT90(cmplx &a) { auto tmp_=a.r; a.r=-a.i; a.i=tmp_; } template void ROTX90(cmplx &a) { auto tmp_= fwd ? -a.r : a.r; a.r = fwd ? 
a.i : -a.i; a.i=tmp_; } // // twiddle factor section // template class sincos_2pibyn { private: using Thigh = typename std::conditional<(sizeof(T)>sizeof(double)), T, double>::type; size_t N, mask, shift; arr> v1, v2; static cmplx calc(size_t x, size_t n, Thigh ang) { x<<=3; if (x<4*n) // first half { if (x<2*n) // first quadrant { if (x(std::cos(Thigh(x)*ang), std::sin(Thigh(x)*ang)); return cmplx(std::sin(Thigh(2*n-x)*ang), std::cos(Thigh(2*n-x)*ang)); } else // second quadrant { x-=2*n; if (x(-std::sin(Thigh(x)*ang), std::cos(Thigh(x)*ang)); return cmplx(-std::cos(Thigh(2*n-x)*ang), std::sin(Thigh(2*n-x)*ang)); } } else { x=8*n-x; if (x<2*n) // third quadrant { if (x(std::cos(Thigh(x)*ang), -std::sin(Thigh(x)*ang)); return cmplx(std::sin(Thigh(2*n-x)*ang), -std::cos(Thigh(2*n-x)*ang)); } else // fourth quadrant { x-=2*n; if (x(-std::sin(Thigh(x)*ang), -std::cos(Thigh(x)*ang)); return cmplx(-std::cos(Thigh(2*n-x)*ang), -std::sin(Thigh(2*n-x)*ang)); } } } public: POCKETFFT_NOINLINE sincos_2pibyn(size_t n) : N(n) { constexpr auto pi = 3.141592653589793238462643383279502884197L; Thigh ang = Thigh(0.25L*pi/n); size_t nval = (n+2)/2; shift = 1; while((size_t(1)< operator[](size_t idx) const { if (2*idx<=N) { auto x1=v1[idx&mask], x2=v2[idx>>shift]; return cmplx(T(x1.r*x2.r-x1.i*x2.i), T(x1.r*x2.i+x1.i*x2.r)); } idx = N-idx; auto x1=v1[idx&mask], x2=v2[idx>>shift]; return cmplx(T(x1.r*x2.r-x1.i*x2.i), -T(x1.r*x2.i+x1.i*x2.r)); } }; struct util // hack to avoid duplicate symbols { static POCKETFFT_NOINLINE size_t largest_prime_factor (size_t n) { size_t res=1; while ((n&1)==0) { res=2; n>>=1; } for (size_t x=3; x*x<=n; x+=2) while ((n%x)==0) { res=x; n/=x; } if (n>1) res=n; return res; } static POCKETFFT_NOINLINE double cost_guess (size_t n) { constexpr double lfp=1.1; // penalty for non-hardcoded larger factors size_t ni=n; double result=0.; while ((n&1)==0) { result+=2; n>>=1; } for (size_t x=3; x*x<=n; x+=2) while ((n%x)==0) { result+= (x<=5) ? 
double(x) : lfp*double(x); // penalize larger prime factors n/=x; } if (n>1) result+=(n<=5) ? double(n) : lfp*double(n); return result*double(ni); } /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */ static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n) { if (n<=12) return n; size_t bestfac=2*n; for (size_t f11=1; f11n) { if (x>=1; } else return n; } } return bestfac; } /* returns the smallest composite of 2, 3, 5 which is >= n */ static POCKETFFT_NOINLINE size_t good_size_real(size_t n) { if (n<=6) return n; size_t bestfac=2*n; for (size_t f5=1; f5n) { if (x>=1; } else return n; } } return bestfac; } static size_t prod(const shape_t &shape) { size_t res=1; for (auto sz: shape) res*=sz; return res; } static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, bool inplace) { auto ndim = shape.size(); if (ndim<1) throw std::runtime_error("ndim must be >= 1"); if ((stride_in.size()!=ndim) || (stride_out.size()!=ndim)) throw std::runtime_error("stride dimension mismatch"); if (inplace && (stride_in!=stride_out)) throw std::runtime_error("stride mismatch"); } static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, bool inplace, const shape_t &axes) { sanity_check(shape, stride_in, stride_out, inplace); auto ndim = shape.size(); shape_t tmp(ndim,0); for (auto ax : axes) { if (ax>=ndim) throw std::invalid_argument("bad axis number"); if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly"); } } static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, bool inplace, size_t axis) { sanity_check(shape, stride_in, stride_out, inplace); if (axis>=shape.size()) throw std::invalid_argument("bad axis number"); } #ifdef POCKETFFT_NO_MULTITHREADING static size_t thread_count (size_t /*nthreads*/, const shape_t &/*shape*/, size_t /*axis*/, size_t 
/*vlen*/) { return 1; } #else static size_t thread_count (size_t nthreads, const shape_t &shape, size_t axis, size_t vlen) { if (nthreads==1) return 1; size_t size = prod(shape); size_t parallel = size / (shape[axis] * vlen); if (shape[axis] < 1000) parallel /= 4; size_t max_threads = nthreads == 0 ? std::thread::hardware_concurrency() : nthreads; return std::max(size_t(1), std::min(parallel, max_threads)); } #endif }; namespace threading { #ifdef POCKETFFT_NO_MULTITHREADING constexpr inline size_t thread_id() { return 0; } constexpr inline size_t num_threads() { return 1; } template void thread_map(size_t /* nthreads */, Func f) { f(); } #else inline size_t &thread_id() { static thread_local size_t thread_id_=0; return thread_id_; } inline size_t &num_threads() { static thread_local size_t num_threads_=1; return num_threads_; } static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency()); class latch { std::atomic num_left_; std::mutex mut_; std::condition_variable completed_; using lock_t = std::unique_lock; public: latch(size_t n): num_left_(n) {} void count_down() { lock_t lock(mut_); if (--num_left_) return; completed_.notify_all(); } void wait() { lock_t lock(mut_); completed_.wait(lock, [this]{ return is_ready(); }); } bool is_ready() { return num_left_ == 0; } }; template class concurrent_queue { std::queue q_; std::mutex mut_; std::atomic size_; using lock_t = std::lock_guard; public: void push(T val) { lock_t lock(mut_); ++size_; q_.push(std::move(val)); } bool try_pop(T &val) { if (size_ == 0) return false; lock_t lock(mut_); // Queue might have been emptied while we acquired the lock if (q_.empty()) return false; val = std::move(q_.front()); --size_; q_.pop(); return true; } bool empty() const { return size_==0; } }; // C++ allocator with support for over-aligned types template struct aligned_allocator { using value_type = T; template aligned_allocator(const aligned_allocator&) {} aligned_allocator() = default; T *allocate(size_t 
n) { void* mem = aligned_alloc(alignof(T), n*sizeof(T)); return static_cast(mem); } void deallocate(T *p, size_t /*n*/) { aligned_dealloc(p); } }; class thread_pool { // A reasonable guess, probably close enough for most hardware static constexpr size_t cache_line_size = 64; struct alignas(cache_line_size) worker { std::thread thread; std::condition_variable work_ready; std::mutex mut; std::atomic_flag busy_flag = ATOMIC_FLAG_INIT; std::function work; void worker_main( std::atomic &shutdown_flag, std::atomic &unscheduled_tasks, concurrent_queue> &overflow_work) { using lock_t = std::unique_lock; bool expect_work = true; while (!shutdown_flag || expect_work) { std::function local_work; if (expect_work || unscheduled_tasks == 0) { lock_t lock(mut); // Wait until there is work to be executed work_ready.wait(lock, [&]{ return (work || shutdown_flag); }); local_work.swap(work); expect_work = false; } bool marked_busy = false; if (local_work) { marked_busy = true; local_work(); } if (!overflow_work.empty()) { if (!marked_busy && busy_flag.test_and_set()) { expect_work = true; continue; } marked_busy = true; while (overflow_work.try_pop(local_work)) { --unscheduled_tasks; local_work(); } } if (marked_busy) busy_flag.clear(); } } }; concurrent_queue> overflow_work_; std::mutex mut_; std::vector> workers_; std::atomic shutdown_; std::atomic unscheduled_tasks_; using lock_t = std::lock_guard; void create_threads() { lock_t lock(mut_); size_t nthreads=workers_.size(); for (size_t i=0; ibusy_flag.clear(); worker->work = nullptr; worker->thread = std::thread([worker, this] { worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_); }); } catch (...) 
{ shutdown_locked(); throw; } } } void shutdown_locked() { shutdown_ = true; for (auto &worker : workers_) worker.work_ready.notify_all(); for (auto &worker : workers_) if (worker.thread.joinable()) worker.thread.join(); } public: explicit thread_pool(size_t nthreads): workers_(nthreads) { create_threads(); } thread_pool(): thread_pool(max_threads) {} ~thread_pool() { shutdown(); } void submit(std::function work) { lock_t lock(mut_); if (shutdown_) throw std::runtime_error("Work item submitted after shutdown"); ++unscheduled_tasks_; // First check for any idle workers and wake those for (auto &worker : workers_) if (!worker.busy_flag.test_and_set()) { --unscheduled_tasks_; { lock_t lock(worker.mut); worker.work = std::move(work); } worker.work_ready.notify_one(); return; } // If no workers were idle, push onto the overflow queue for later overflow_work_.push(std::move(work)); } void shutdown() { lock_t lock(mut_); shutdown_locked(); } void restart() { shutdown_ = false; create_threads(); } }; inline thread_pool & get_pool() { static thread_pool pool; #ifdef POCKETFFT_PTHREADS static std::once_flag f; std::call_once(f, []{ pthread_atfork( +[]{ get_pool().shutdown(); }, // prepare +[]{ get_pool().restart(); }, // parent +[]{ get_pool().restart(); } // child ); }); #endif return pool; } /** Map a function f over nthreads */ template void thread_map(size_t nthreads, Func f) { if (nthreads == 0) nthreads = max_threads; if (nthreads == 1) { f(); return; } auto & pool = get_pool(); latch counter(nthreads); std::exception_ptr ex; std::mutex ex_mut; for (size_t i=0; i lock(ex_mut); ex = std::current_exception(); } counter.count_down(); }); } counter.wait(); if (ex) std::rethrow_exception(ex); } #endif } // // complex FFTPACK transforms // template class cfftp { private: struct fctdata { size_t fct; cmplx *tw, *tws; }; size_t length; arr> mem; std::vector fact; void add_factor(size_t factor) { fact.push_back({factor, nullptr, nullptr}); } template void pass2 (size_t ido, 
size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const cmplx * POCKETFFT_RESTRICT wa) const { auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+2*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; if (ido==1) for (size_t k=0; k(CC(i,0,k)-CC(i,1,k),WA(0,i),CH(i,k,1)); } } } #define POCKETFFT_PREP3(idx) \ T t0 = CC(idx,0,k), t1, t2; \ PM (t1,t2,CC(idx,1,k),CC(idx,2,k)); \ CH(idx,k,0)=t0+t1; #define POCKETFFT_PARTSTEP3a(u1,u2,twr,twi) \ { \ T ca=t0+t1*twr; \ T cb{-t2.i*twi, t2.r*twi}; \ PM(CH(0,k,u1),CH(0,k,u2),ca,cb) ;\ } #define POCKETFFT_PARTSTEP3b(u1,u2,twr,twi) \ { \ T ca=t0+t1*twr; \ T cb{-t2.i*twi, t2.r*twi}; \ special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ } template void pass3 (size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const cmplx * POCKETFFT_RESTRICT wa) const { constexpr T0 tw1r=-0.5, tw1i= (fwd ? 
-1: 1) * T0(0.8660254037844386467637231707529362L); auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+3*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; if (ido==1) for (size_t k=0; k void pass4 (size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const cmplx * POCKETFFT_RESTRICT wa) const { auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+4*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; if (ido==1) for (size_t k=0; k(t4); PM(CH(0,k,0),CH(0,k,2),t2,t3); PM(CH(0,k,1),CH(0,k,3),t1,t4); } else for (size_t k=0; k(t4); PM(CH(0,k,0),CH(0,k,2),t2,t3); PM(CH(0,k,1),CH(0,k,3),t1,t4); } for (size_t i=1; i(t4); CH(i,k,0) = t2+t3; special_mul(t1+t4,WA(0,i),CH(i,k,1)); special_mul(t2-t3,WA(1,i),CH(i,k,2)); special_mul(t1-t4,WA(2,i),CH(i,k,3)); } } } #define POCKETFFT_PREP5(idx) \ T t0 = CC(idx,0,k), t1, t2, t3, t4; \ PM (t1,t4,CC(idx,1,k),CC(idx,4,k)); \ PM (t2,t3,CC(idx,2,k),CC(idx,3,k)); \ CH(idx,k,0).r=t0.r+t1.r+t2.r; \ CH(idx,k,0).i=t0.i+t1.i+t2.i; #define POCKETFFT_PARTSTEP5a(u1,u2,twar,twbr,twai,twbi) \ { \ T ca,cb; \ ca.r=t0.r+twar*t1.r+twbr*t2.r; \ ca.i=t0.i+twar*t1.i+twbr*t2.i; \ cb.i=twai*t4.r twbi*t3.r; \ cb.r=-(twai*t4.i twbi*t3.i); \ PM(CH(0,k,u1),CH(0,k,u2),ca,cb); \ } #define POCKETFFT_PARTSTEP5b(u1,u2,twar,twbr,twai,twbi) \ { \ T ca,cb,da,db; \ ca.r=t0.r+twar*t1.r+twbr*t2.r; \ ca.i=t0.i+twar*t1.i+twbr*t2.i; \ cb.i=twai*t4.r twbi*t3.r; \ cb.r=-(twai*t4.i twbi*t3.i); \ special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ } template void pass5 (size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const cmplx * POCKETFFT_RESTRICT wa) const { constexpr T0 tw1r= 
T0(0.3090169943749474241022934171828191L), tw1i= (fwd ? -1: 1) * T0(0.9510565162951535721164393333793821L), tw2r= T0(-0.8090169943749474241022934171828191L), tw2i= (fwd ? -1: 1) * T0(0.5877852522924731291687059546390728L); auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+5*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; if (ido==1) for (size_t k=0; k(da,WA(u1-1,i),CH(i,k,u1)); \ special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ } template void pass7(size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const cmplx * POCKETFFT_RESTRICT wa) const { constexpr T0 tw1r= T0(0.6234898018587335305250048840042398L), tw1i= (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L), tw2r= T0(-0.2225209339563144042889025644967948L), tw2i= (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L), tw3r= T0(-0.9009688679024191262361023195074451L), tw3i= (fwd ? 
-1 : 1) * T0(0.433883739117558120475768332848359L); auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+7*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; if (ido==1) for (size_t k=0; k void ROTX45(T &a) const { constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); if (fwd) { auto tmp_=a.r; a.r=hsqt2*(a.r+a.i); a.i=hsqt2*(a.i-tmp_); } else { auto tmp_=a.r; a.r=hsqt2*(a.r-a.i); a.i=hsqt2*(a.i+tmp_); } } template void ROTX135(T &a) const { constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); if (fwd) { auto tmp_=a.r; a.r=hsqt2*(a.i-a.r); a.i=hsqt2*(-tmp_-a.i); } else { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); } } template void pass8 (size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const cmplx * POCKETFFT_RESTRICT wa) const { auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+8*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; if (ido==1) for (size_t k=0; k(a3); ROTX90(a7); PMINPLACE(a5,a7); ROTX45(a5); ROTX135(a7); PM(a0,a4,CC(0,0,k),CC(0,4,k)); PM(a2,a6,CC(0,2,k),CC(0,6,k)); PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); ROTX90(a6); PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); } else for (size_t k=0; k(a3); ROTX90(a7); PMINPLACE(a5,a7); ROTX45(a5); ROTX135(a7); PM(a0,a4,CC(0,0,k),CC(0,4,k)); PM(a2,a6,CC(0,2,k),CC(0,6,k)); PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); ROTX90(a6); PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); } for (size_t i=1; i(a7); PMINPLACE(a1,a3); ROTX90(a3); PMINPLACE(a5,a7); ROTX45(a5); ROTX135(a7); PM(a0,a4,CC(i,0,k),CC(i,4,k)); PM(a2,a6,CC(i,2,k),CC(i,6,k)); PMINPLACE(a0,a2); CH(i,k,0) = a0+a1; 
special_mul(a0-a1,WA(3,i),CH(i,k,4)); special_mul(a2+a3,WA(1,i),CH(i,k,2)); special_mul(a2-a3,WA(5,i),CH(i,k,6)); ROTX90(a6); PMINPLACE(a4,a6); special_mul(a4+a5,WA(0,i),CH(i,k,1)); special_mul(a4-a5,WA(4,i),CH(i,k,5)); special_mul(a6+a7,WA(2,i),CH(i,k,3)); special_mul(a6-a7,WA(6,i),CH(i,k,7)); } } } #define POCKETFFT_PREP11(idx) \ T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \ PM (t2,t11,CC(idx,1,k),CC(idx,10,k)); \ PM (t3,t10,CC(idx,2,k),CC(idx, 9,k)); \ PM (t4,t9 ,CC(idx,3,k),CC(idx, 8,k)); \ PM (t5,t8 ,CC(idx,4,k),CC(idx, 7,k)); \ PM (t6,t7 ,CC(idx,5,k),CC(idx, 6,k)); \ CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r+t5.r+t6.r; \ CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i+t5.i+t6.i; #define POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,out1,out2) \ { \ T ca = t1 + t2*x1 + t3*x2 + t4*x3 + t5*x4 +t6*x5, \ cb; \ cb.i=y1*t11.r y2*t10.r y3*t9.r y4*t8.r y5*t7.r; \ cb.r=-(y1*t11.i y2*t10.i y3*t9.i y4*t8.i y5*t7.i ); \ PM(out1,out2,ca,cb); \ } #define POCKETFFT_PARTSTEP11a(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,CH(0,k,u1),CH(0,k,u2)) #define POCKETFFT_PARTSTEP11(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ { \ T da,db; \ POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,da,db) \ special_mul(da,WA(u1-1,i),CH(i,k,u1)); \ special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ } template void pass11 (size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const cmplx * POCKETFFT_RESTRICT wa) const { constexpr T0 tw1r= T0(0.8412535328311811688618116489193677L), tw1i= (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L), tw2r= T0(0.4154150130018864255292741492296232L), tw2i= (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L), tw3r= T0(-0.1423148382732851404437926686163697L), tw3i= (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L), tw4r= T0(-0.6548607339452850640569250724662936L), tw4i= (fwd ? 
-1 : 1) * T0(0.7557495743542582837740358439723444L), tw5r= T0(-0.9594929736144973898903680570663277L), tw5i= (fwd ? -1 : 1) * T0(0.2817325568414296977114179153466169L); auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+11*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; if (ido==1) for (size_t k=0; k void passg (size_t ido, size_t ip, size_t l1, T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const cmplx * POCKETFFT_RESTRICT wa, const cmplx * POCKETFFT_RESTRICT csarr) const { const size_t cdim=ip; size_t ipph = (ip+1)/2; size_t idl1 = ido*l1; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T& { return cc[a+ido*(b+l1*c)]; }; auto CX2 = [cc, idl1](size_t a, size_t b) -> T& { return cc[a+idl1*b]; }; auto CH2 = [ch, idl1](size_t a, size_t b) -> const T& { return ch[a+idl1*b]; }; arr> wal(ip); wal[0] = cmplx(1., 0.); for (size_t i=1; i(csarr[i].r,fwd ? 
-csarr[i].i : csarr[i].i); for (size_t k=0; kip) iwal-=ip; cmplx xwal=wal[iwal]; iwal+=l; if (iwal>ip) iwal-=ip; cmplx xwal2=wal[iwal]; for (size_t ik=0; ikip) iwal-=ip; cmplx xwal=wal[iwal]; for (size_t ik=0; ik(x1,wa[idij],CX(i,k,j)); idij=(jc-1)*(ido-1)+i-1; special_mul(x2,wa[idij],CX(i,k,jc)); } } } } template void pass_all(T c[], T0 fct) const { if (length==1) { c[0]*=fct; return; } size_t l1=1; arr ch(length); T *p1=c, *p2=ch.data(); for(size_t k1=0; k1 (ido, l1, p1, p2, fact[k1].tw); else if(ip==8) pass8(ido, l1, p1, p2, fact[k1].tw); else if(ip==2) pass2(ido, l1, p1, p2, fact[k1].tw); else if(ip==3) pass3 (ido, l1, p1, p2, fact[k1].tw); else if(ip==5) pass5 (ido, l1, p1, p2, fact[k1].tw); else if(ip==7) pass7 (ido, l1, p1, p2, fact[k1].tw); else if(ip==11) pass11 (ido, l1, p1, p2, fact[k1].tw); else { passg(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws); std::swap(p1,p2); } std::swap(p1,p2); l1=l2; } if (p1!=c) { if (fct!=1.) for (size_t i=0; i void exec(T c[], T0 fct, bool fwd) const { fwd ? 
pass_all(c, fct) : pass_all(c, fct); } private: POCKETFFT_NOINLINE void factorize() { size_t len=length; while ((len&7)==0) { add_factor(8); len>>=3; } while ((len&3)==0) { add_factor(4); len>>=2; } if ((len&1)==0) { len>>=1; // factor 2 should be at the front of the factor list add_factor(2); std::swap(fact[0].fct, fact.back().fct); } for (size_t divisor=3; divisor*divisor<=len; divisor+=2) while ((len%divisor)==0) { add_factor(divisor); len/=divisor; } if (len>1) add_factor(len); } size_t twsize() const { size_t twsize=0, l1=1; for (size_t k=0; k11) twsize+=ip; l1*=ip; } return twsize; } void comp_twiddle() { sincos_2pibyn twiddle(length); size_t l1=1; size_t memofs=0; for (size_t k=0; k11) { fact[k].tws=mem.data()+memofs; memofs+=ip; for (size_t j=0; j class rfftp { private: struct fctdata { size_t fct; T0 *tw, *tws; }; size_t length; arr mem; std::vector fact; void add_factor(size_t factor) { fact.push_back({factor, nullptr, nullptr}); } /* (a+ib) = conj(c+id) * (e+if) */ template inline void MULPM (T1 &a, T1 &b, T2 c, T2 d, T3 e, T3 f) const { a=c*e+d*f; b=c*f-d*e; } template void radf2 (size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T0 * POCKETFFT_RESTRICT wa) const { auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+l1*c)]; }; auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+2*c)]; }; for (size_t k=0; k void radf3(size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T0 * POCKETFFT_RESTRICT wa) const { constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+l1*c)]; }; auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+3*c)]; }; for (size_t k=0; k void radf4(size_t 
ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T0 * POCKETFFT_RESTRICT wa) const { constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+l1*c)]; }; auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+4*c)]; }; for (size_t k=0; k void radf5(size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T0 * POCKETFFT_RESTRICT wa) const { constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), ti11= T0(0.9510565162951535721164393333793821L), tr12= T0(-0.8090169943749474241022934171828191L), ti12= T0(0.5877852522924731291687059546390728L); auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+l1*c)]; }; auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+5*c)]; }; for (size_t k=0; k void radfg(size_t ido, size_t ip, size_t l1, T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const { const size_t cdim=ip; size_t ipph=(ip+1)/2; size_t idl1 = ido*l1; auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> T& { return cc[a+ido*(b+cdim*c)]; }; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> const T& { return ch[a+ido*(b+l1*c)]; }; auto C1 = [cc,ido,l1] (size_t a, size_t b, size_t c) -> T& { return cc[a+ido*(b+l1*c)]; }; auto C2 = [cc,idl1] (size_t a, size_t b) -> T& { return cc[a+idl1*b]; }; auto CH2 = [ch,idl1] (size_t a, size_t b) -> T& { return ch[a+idl1*b]; }; if (ido>1) { for (size_t j=1, jc=ip-1; j=ip) iang-=ip; T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; iang+=l; if (iang>=ip) iang-=ip; T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; iang+=l; if (iang>=ip) iang-=ip; T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; iang+=l; if 
(iang>=ip) iang-=ip; T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; for (size_t ik=0; ik=ip) iang-=ip; T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; iang+=l; if (iang>=ip) iang-=ip; T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; for (size_t ik=0; ik=ip) iang-=ip; T0 ar=csarr[2*iang], ai=csarr[2*iang+1]; for (size_t ik=0; ik void radb2(size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T0 * POCKETFFT_RESTRICT wa) const { auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+2*c)]; }; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; for (size_t k=0; k void radb3(size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T0 * POCKETFFT_RESTRICT wa) const { constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+3*c)]; }; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; for (size_t k=0; k void radb4(size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T0 * POCKETFFT_RESTRICT wa) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+4*c)]; }; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; for (size_t k=0; k void radb5(size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T0 * POCKETFFT_RESTRICT wa) const { constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), ti11= T0(0.9510565162951535721164393333793821L), tr12= T0(-0.8090169943749474241022934171828191L), ti12= 
T0(0.5877852522924731291687059546390728L); auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+5*c)]; }; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; for (size_t k=0; k void radbg(size_t ido, size_t ip, size_t l1, T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const { const size_t cdim=ip; size_t ipph=(ip+1)/ 2; size_t idl1 = ido*l1; auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; auto C1 = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+l1*c)]; }; auto C2 = [cc,idl1](size_t a, size_t b) -> T& { return cc[a+idl1*b]; }; auto CH2 = [ch,idl1](size_t a, size_t b) -> T& { return ch[a+idl1*b]; }; for (size_t k=0; kip) iang-=ip; T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; iang+=l; if(iang>ip) iang-=ip; T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; iang+=l; if(iang>ip) iang-=ip; T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; iang+=l; if(iang>ip) iang-=ip; T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; for (size_t ik=0; ikip) iang-=ip; T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; iang+=l; if(iang>ip) iang-=ip; T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; for (size_t ik=0; ikip) iang-=ip; T0 war=csarr[2*iang], wai=csarr[2*iang+1]; for (size_t ik=0; ik void copy_and_norm(T *c, T *p1, T0 fct) const { if (p1!=c) { if (fct!=1.) 
for (size_t i=0; i void exec(T c[], T0 fct, bool r2hc) const { if (length==1) { c[0]*=fct; return; } size_t nf=fact.size(); arr ch(length); T *p1=c, *p2=ch.data(); if (r2hc) for(size_t k1=0, l1=length; k1>=2; } if ((len%2)==0) { len>>=1; // factor 2 should be at the front of the factor list add_factor(2); std::swap(fact[0].fct, fact.back().fct); } for (size_t divisor=3; divisor*divisor<=len; divisor+=2) while ((len%divisor)==0) { add_factor(divisor); len/=divisor; } if (len>1) add_factor(len); } size_t twsize() const { size_t twsz=0, l1=1; for (size_t k=0; k5) twsz+=2*ip; l1*=ip; } return twsz; } void comp_twiddle() { sincos_2pibyn twid(length); size_t l1=1; T0 *ptr=mem.data(); for (size_t k=0; k5) // special factors required by *g functions { fact[k].tws=ptr; ptr+=2*ip; fact[k].tws[0] = 1.; fact[k].tws[1] = 0.; for (size_t i=2, ic=2*ip-2; i<=ic; i+=2, ic-=2) { fact[k].tws[i ] = twid[i/2*(length/ip)].r; fact[k].tws[i+1] = twid[i/2*(length/ip)].i; fact[k].tws[ic] = twid[i/2*(length/ip)].r; fact[k].tws[ic+1] = -twid[i/2*(length/ip)].i; } } l1*=ip; } } public: POCKETFFT_NOINLINE rfftp(size_t length_) : length(length_) { if (length==0) throw std::runtime_error("zero-length FFT requested"); if (length==1) return; factorize(); mem.resize(twsize()); comp_twiddle(); } }; // // complex Bluestein transforms // template class fftblue { private: size_t n, n2; cfftp plan; arr> mem; cmplx *bk, *bkf; template void fft(cmplx c[], T0 fct) const { arr> akf(n2); /* initialize a_k and FFT it */ for (size_t m=0; m(c[m],bk[m],akf[m]); auto zero = akf[0]*T0(0); for (size_t m=n; m(bkf[0]); for (size_t m=1; m<(n2+1)/2; ++m) { akf[m] = akf[m].template special_mul(bkf[m]); akf[n2-m] = akf[n2-m].template special_mul(bkf[m]); } if ((n2&1)==0) akf[n2/2] = akf[n2/2].template special_mul(bkf[n2/2]); /* inverse FFT */ plan.exec (akf.data(),1.,false); /* multiply by b_k */ for (size_t m=0; m(bk[m])*fct; } public: POCKETFFT_NOINLINE fftblue(size_t length) : n(length), 
n2(util::good_size_cmplx(n*2-1)), plan(n2), mem(n+n2/2+1), bk(mem.data()), bkf(mem.data()+n) { /* initialize b_k */ sincos_2pibyn tmp(2*n); bk[0].Set(1, 0); size_t coeff=0; for (size_t m=1; m=2*n) coeff-=2*n; bk[m] = tmp[coeff]; } /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */ arr> tbkf(n2); T0 xn2 = T0(1)/T0(n2); tbkf[0] = bk[0]*xn2; for (size_t m=1; m void exec(cmplx c[], T0 fct, bool fwd) const { fwd ? fft(c,fct) : fft(c,fct); } template void exec_r(T c[], T0 fct, bool fwd) { arr> tmp(n); if (fwd) { auto zero = T0(0)*c[0]; for (size_t m=0; m(tmp.data(),fct); c[0] = tmp[0].r; std::copy_n (&tmp[1].r, n-1, &c[1]); } else { tmp[0].Set(c[0],c[0]*0); std::copy_n (c+1, n-1, &tmp[1].r); if ((n&1)==0) tmp[n/2].i=T0(0)*c[0]; for (size_t m=1; 2*m(tmp.data(),fct); for (size_t m=0; m class pocketfft_c { private: std::unique_ptr> packplan; std::unique_ptr> blueplan; size_t len; public: POCKETFFT_NOINLINE pocketfft_c(size_t length) : len(length) { if (length==0) throw std::runtime_error("zero-length FFT requested"); size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); if (tmp*tmp <= length) { packplan=std::unique_ptr>(new cfftp(length)); return; } double comp1 = util::cost_guess(length); double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); comp2*=1.5; /* fudge factor that appears to give good overall performance */ if (comp2>(new fftblue(length)); else packplan=std::unique_ptr>(new cfftp(length)); } template POCKETFFT_NOINLINE void exec(cmplx c[], T0 fct, bool fwd) const { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec(c,fct,fwd); } size_t length() const { return len; } }; // // flexible (FFTPACK/Bluestein) real-valued 1D transform // template class pocketfft_r { private: std::unique_ptr> packplan; std::unique_ptr> blueplan; size_t len; public: POCKETFFT_NOINLINE pocketfft_r(size_t length) : len(length) { if (length==0) throw std::runtime_error("zero-length FFT requested"); size_t tmp = (length<50) ? 
0 : util::largest_prime_factor(length); if (tmp*tmp <= length) { packplan=std::unique_ptr>(new rfftp(length)); return; } double comp1 = 0.5*util::cost_guess(length); double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); comp2*=1.5; /* fudge factor that appears to give good overall performance */ if (comp2>(new fftblue(length)); else packplan=std::unique_ptr>(new rfftp(length)); } template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec_r(c,fct,fwd); } size_t length() const { return len; } }; // // sine/cosine transforms // template class T_dct1 { private: pocketfft_r fftplan; public: POCKETFFT_NOINLINE T_dct1(size_t length) : fftplan(2*(length-1)) {} template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, int /*type*/, bool /*cosine*/) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=fftplan.length(), n=N/2+1; if (ortho) { c[0]*=sqrt2; c[n-1]*=sqrt2; } arr tmp(N); tmp[0] = c[0]; for (size_t i=1; i class T_dst1 { private: pocketfft_r fftplan; public: POCKETFFT_NOINLINE T_dst1(size_t length) : fftplan(2*(length+1)) {} template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/, bool /*cosine*/) const { size_t N=fftplan.length(), n=N/2-1; arr tmp(N); tmp[0] = tmp[n+1] = c[0]*0; for (size_t i=0; i class T_dcst23 { private: pocketfft_r fftplan; std::vector twiddle; public: POCKETFFT_NOINLINE T_dcst23(size_t length) : fftplan(length), twiddle(length) { sincos_2pibyn tw(4*length); for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, int type, bool cosine) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); size_t NS2 = (N+1)/2; if (type==2) { if (!cosine) for (size_t k=1; k class T_dcst4 { private: size_t N; std::unique_ptr> fft; std::unique_ptr> rfft; arr> C2; public: POCKETFFT_NOINLINE T_dcst4(size_t length) : N(length), fft((N&1) ? 
nullptr : new pocketfft_c(N/2)), rfft((N&1)? new pocketfft_r(N) : nullptr), C2((N&1) ? 0 : N/2) { if ((N&1)==0) { sincos_2pibyn tw(16*N); for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/, bool cosine) const { size_t n2 = N/2; if (!cosine) for (size_t k=0, kc=N-1; k y(N); { size_t i=0, m=n2; for (; mexec(y.data(), fct, true); { auto SGN = [](size_t i) { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); return (i&2) ? -sqrt2 : sqrt2; }; c[n2] = y[0]*SGN(n2+1); size_t i=0, i1=1, k=1; for (; k> y(n2); for(size_t i=0; iexec(y.data(), fct, true); for(size_t i=0, ic=n2-1; i std::shared_ptr get_plan(size_t length) { #if POCKETFFT_CACHE_SIZE==0 return std::make_shared(length); #else constexpr size_t nmax=POCKETFFT_CACHE_SIZE; static std::array, nmax> cache; static std::array last_access{{0}}; static size_t access_counter = 0; static std::mutex mut; auto find_in_cache = [&]() -> std::shared_ptr { for (size_t i=0; ilength()==length)) { // no need to update if this is already the most recent entry if (last_access[i]!=access_counter) { last_access[i] = ++access_counter; // Guard against overflow if (access_counter == 0) last_access.fill(0); } return cache[i]; } return nullptr; }; { std::lock_guard lock(mut); auto p = find_in_cache(); if (p) return p; } auto plan = std::make_shared(length); { std::lock_guard lock(mut); auto p = find_in_cache(); if (p) return p; size_t lru = 0; for (size_t i=1; i class cndarr: public arr_info { protected: const char *d; public: cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_) : arr_info(shape_, stride_), d(reinterpret_cast(data_)) {} const T &operator[](ptrdiff_t ofs) const { return *reinterpret_cast(d+ofs); } }; template class ndarr: public cndarr { public: ndarr(void *data_, const shape_t &shape_, const stride_t &stride_) : cndarr::cndarr(const_cast(data_), shape_, stride_) {} T &operator[](ptrdiff_t ofs) { return *reinterpret_cast(const_cast(cndarr::d+ofs)); } }; 
template class multi_iter { private: shape_t pos; const arr_info &iarr, &oarr; ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; size_t idim, rem; void advance_i() { for (int i_=int(pos.size())-1; i_>=0; --i_) { auto i = size_t(i_); if (i==idim) continue; p_ii += iarr.stride(i); p_oi += oarr.stride(i); if (++pos[i] < iarr.shape(i)) return; pos[i] = 0; p_ii -= ptrdiff_t(iarr.shape(i))*iarr.stride(i); p_oi -= ptrdiff_t(oarr.shape(i))*oarr.stride(i); } } public: multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_) : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), idim(idim_), rem(iarr.size()/iarr.shape(idim)) { auto nshares = threading::num_threads(); if (nshares==1) return; if (nshares==0) throw std::runtime_error("can't run with zero threads"); auto myshare = threading::thread_id(); if (myshare>=nshares) throw std::runtime_error("impossible share requested"); size_t nbase = rem/nshares; size_t additional = rem%nshares; size_t lo = myshare*nbase + ((myshare=0; --i_) { auto i = size_t(i_); p += arr.stride(i); if (++pos[i] < arr.shape(i)) return; pos[i] = 0; p -= ptrdiff_t(arr.shape(i))*arr.stride(i); } } ptrdiff_t ofs() const { return p; } size_t remaining() const { return rem; } }; class rev_iter { private: shape_t pos; const arr_info &arr; std::vector rev_axis; std::vector rev_jump; size_t last_axis, last_size; shape_t shp; ptrdiff_t p, rp; size_t rem; public: rev_iter(const arr_info &arr_, const shape_t &axes) : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0), rev_jump(arr_.ndim(), 1), p(0), rp(0) { for (auto ax: axes) rev_axis[ax]=1; last_axis = axes.back(); last_size = arr.shape(last_axis)/2 + 1; shp = arr.shape(); shp[last_axis] = last_size; rem=1; for (auto i: shp) rem *= i; } void advance() { --rem; for (int i_=int(pos.size())-1; i_>=0; --i_) { auto i = size_t(i_); p += arr.stride(i); if (!rev_axis[i]) rp += arr.stride(i); else { rp -= arr.stride(i); if 
(rev_jump[i]) { rp += ptrdiff_t(arr.shape(i))*arr.stride(i); rev_jump[i] = 0; } } if (++pos[i] < shp[i]) return; pos[i] = 0; p -= ptrdiff_t(shp[i])*arr.stride(i); if (rev_axis[i]) { rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i); rev_jump[i] = 1; } else rp -= ptrdiff_t(shp[i])*arr.stride(i); } } ptrdiff_t ofs() const { return p; } ptrdiff_t rev_ofs() const { return rp; } size_t remaining() const { return rem; } }; template struct VTYPE {}; template using vtype_t = typename VTYPE::type; #ifndef POCKETFFT_NO_VECTORS template<> struct VTYPE { using type = float __attribute__ ((vector_size (VLEN::val*sizeof(float)))); }; template<> struct VTYPE { using type = double __attribute__ ((vector_size (VLEN::val*sizeof(double)))); }; template<> struct VTYPE { using type = long double __attribute__ ((vector_size (VLEN::val*sizeof(long double)))); }; #endif template arr alloc_tmp(const shape_t &shape, size_t axsize, size_t elemsize) { auto othersize = util::prod(shape)/axsize; auto tmpsize = axsize*((othersize>=VLEN::val) ? VLEN::val : 1); return arr(tmpsize*elemsize); } template arr alloc_tmp(const shape_t &shape, const shape_t &axes, size_t elemsize) { size_t fullsize=util::prod(shape); size_t tmpsize=0; for (size_t i=0; i=VLEN::val) ? 
VLEN::val : 1); if (sz>tmpsize) tmpsize=sz; } return arr(tmpsize*elemsize); } template void copy_input(const multi_iter &it, const cndarr> &src, cmplx> *POCKETFFT_RESTRICT dst) { for (size_t i=0; i void copy_input(const multi_iter &it, const cndarr &src, vtype_t *POCKETFFT_RESTRICT dst) { for (size_t i=0; i void copy_input(const multi_iter &it, const cndarr &src, T *POCKETFFT_RESTRICT dst) { if (dst == &src[it.iofs(0)]) return; // in-place for (size_t i=0; i void copy_output(const multi_iter &it, const cmplx> *POCKETFFT_RESTRICT src, ndarr> &dst) { for (size_t i=0; i void copy_output(const multi_iter &it, const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) { for (size_t i=0; i void copy_output(const multi_iter &it, const T *POCKETFFT_RESTRICT src, ndarr &dst) { if (src == &dst[it.oofs(0)]) return; // in-place for (size_t i=0; i struct add_vec { using type = vtype_t; }; template struct add_vec> { using type = cmplx>; }; template using add_vec_t = typename add_vec::type; template POCKETFFT_NOINLINE void general_nd(const cndarr &in, ndarr &out, const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec, const bool allow_inplace=true) { std::shared_ptr plan; for (size_t iax=0; iaxlength())) plan = get_plan(len); threading::thread_map( util::thread_count(nthreads, in.shape(), axes[iax], VLEN::val), [&] { constexpr auto vlen = VLEN::val; auto storage = alloc_tmp(in.shape(), len, sizeof(T)); const auto &tin(iax==0? in : out); multi_iter it(tin, out, axes[iax]); #ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen) { it.advance(vlen); auto tdatav = reinterpret_cast *>(storage.data()); exec(it, tin, out, tdatav, *plan, fct); } #endif while (it.remaining()>0) { it.advance(1); auto buf = allow_inplace && it.stride_out() == sizeof(T) ? 
&out[it.oofs(0)] : reinterpret_cast(storage.data()); exec(it, tin, out, buf, *plan, fct); } }); // end of parallel region fct = T0(1); // factor has been applied, use 1 for remaining axes } } struct ExecC2C { bool forward; template void operator () ( const multi_iter &it, const cndarr> &in, ndarr> &out, T * buf, const pocketfft_c &plan, T0 fct) const { copy_input(it, in, buf); plan.exec(buf, fct, forward); copy_output(it, buf, out); } }; template void copy_hartley(const multi_iter &it, const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) { for (size_t j=0; j void copy_hartley(const multi_iter &it, const T *POCKETFFT_RESTRICT src, ndarr &dst) { dst[it.oofs(0)] = src[0]; size_t i=1, i1=1, i2=it.length_out()-1; for (i=1; i void operator () ( const multi_iter &it, const cndarr &in, ndarr &out, T * buf, const pocketfft_r &plan, T0 fct) const { copy_input(it, in, buf); plan.exec(buf, fct, true); copy_hartley(it, buf, out); } }; struct ExecDcst { bool ortho; int type; bool cosine; template void operator () (const multi_iter &it, const cndarr &in, ndarr &out, T * buf, const Tplan &plan, T0 fct) const { copy_input(it, in, buf); plan.exec(buf, fct, ortho, type, cosine); copy_output(it, buf, out); } }; template POCKETFFT_NOINLINE void general_r2c( const cndarr &in, ndarr> &out, size_t axis, bool forward, T fct, size_t nthreads) { auto plan = get_plan>(in.shape(axis)); size_t len=in.shape(axis); threading::thread_map( util::thread_count(nthreads, in.shape(), axis, VLEN::val), [&] { constexpr auto vlen = VLEN::val; auto storage = alloc_tmp(in.shape(), len, sizeof(T)); multi_iter it(in, out, axis); #ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen) { it.advance(vlen); auto tdatav = reinterpret_cast *>(storage.data()); copy_input(it, in, tdatav); plan->exec(tdatav, fct, true); for (size_t j=0; j0) { it.advance(1); auto tdata = reinterpret_cast(storage.data()); copy_input(it, in, tdata); plan->exec(tdata, fct, true); out[it.oofs(0)].Set(tdata[0]); size_t i=1, 
ii=1; if (forward) for (; i POCKETFFT_NOINLINE void general_c2r( const cndarr> &in, ndarr &out, size_t axis, bool forward, T fct, size_t nthreads) { auto plan = get_plan>(out.shape(axis)); size_t len=out.shape(axis); threading::thread_map( util::thread_count(nthreads, in.shape(), axis, VLEN::val), [&] { constexpr auto vlen = VLEN::val; auto storage = alloc_tmp(out.shape(), len, sizeof(T)); multi_iter it(in, out, axis); #ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen) { it.advance(vlen); auto tdatav = reinterpret_cast *>(storage.data()); for (size_t j=0; jexec(tdatav, fct, false); copy_output(it, tdatav, out); } #endif while (it.remaining()>0) { it.advance(1); auto tdata = reinterpret_cast(storage.data()); tdata[0]=in[it.iofs(0)].r; { size_t i=1, ii=1; if (forward) for (; iexec(tdata, fct, false); copy_output(it, tdata, out); } }); // end of parallel region } struct ExecR2R { bool r2h, forward; template void operator () ( const multi_iter &it, const cndarr &in, ndarr &out, T * buf, const pocketfft_r &plan, T0 fct) const { copy_input(it, in, buf); if ((!r2h) && forward) for (size_t i=2; i void c2c(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, bool forward, const std::complex *data_in, std::complex *data_out, T fct, size_t nthreads=1) { if (util::prod(shape)==0) return; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); cndarr> ain(data_in, shape, stride_in); ndarr> aout(data_out, shape, stride_out); general_nd>(ain, aout, axes, fct, nthreads, ExecC2C{forward}); } template void dct(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) { if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type"); if (util::prod(shape)==0) return; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); cndarr ain(data_in, shape, 
stride_in); ndarr aout(data_out, shape, stride_out); const ExecDcst exec{ortho, type, true}; if (type==1) general_nd>(ain, aout, axes, fct, nthreads, exec); else if (type==4) general_nd>(ain, aout, axes, fct, nthreads, exec); else general_nd>(ain, aout, axes, fct, nthreads, exec); } template void dst(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) { if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type"); if (util::prod(shape)==0) return; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); const ExecDcst exec{ortho, type, false}; if (type==1) general_nd>(ain, aout, axes, fct, nthreads, exec); else if (type==4) general_nd>(ain, aout, axes, fct, nthreads, exec); else general_nd>(ain, aout, axes, fct, nthreads, exec); } template void r2c(const shape_t &shape_in, const stride_t &stride_in, const stride_t &stride_out, size_t axis, bool forward, const T *data_in, std::complex *data_out, T fct, size_t nthreads=1) { if (util::prod(shape_in)==0) return; util::sanity_check(shape_in, stride_in, stride_out, false, axis); cndarr ain(data_in, shape_in, stride_in); shape_t shape_out(shape_in); shape_out[axis] = shape_in[axis]/2 + 1; ndarr> aout(data_out, shape_out, stride_out); general_r2c(ain, aout, axis, forward, fct, nthreads); } template void r2c(const shape_t &shape_in, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, bool forward, const T *data_in, std::complex *data_out, T fct, size_t nthreads=1) { if (util::prod(shape_in)==0) return; util::sanity_check(shape_in, stride_in, stride_out, false, axes); r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out, fct, nthreads); if (axes.size()==1) return; shape_t shape_out(shape_in); shape_out[axes.back()] = shape_in[axes.back()]/2 + 1; 
auto newaxes = shape_t{axes.begin(), --axes.end()}; c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out, T(1), nthreads); } template void c2r(const shape_t &shape_out, const stride_t &stride_in, const stride_t &stride_out, size_t axis, bool forward, const std::complex *data_in, T *data_out, T fct, size_t nthreads=1) { if (util::prod(shape_out)==0) return; util::sanity_check(shape_out, stride_in, stride_out, false, axis); shape_t shape_in(shape_out); shape_in[axis] = shape_out[axis]/2 + 1; cndarr> ain(data_in, shape_in, stride_in); ndarr aout(data_out, shape_out, stride_out); general_c2r(ain, aout, axis, forward, fct, nthreads); } template void c2r(const shape_t &shape_out, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, bool forward, const std::complex *data_in, T *data_out, T fct, size_t nthreads=1) { if (util::prod(shape_out)==0) return; if (axes.size()==1) return c2r(shape_out, stride_in, stride_out, axes[0], forward, data_in, data_out, fct, nthreads); util::sanity_check(shape_out, stride_in, stride_out, false, axes); auto shape_in = shape_out; shape_in[axes.back()] = shape_out[axes.back()]/2 + 1; auto nval = util::prod(shape_in); stride_t stride_inter(shape_in.size()); stride_inter.back() = sizeof(cmplx); for (int i=int(shape_in.size())-2; i>=0; --i) stride_inter[size_t(i)] = stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]); arr> tmp(nval); auto newaxes = shape_t{axes.begin(), --axes.end()}; c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(), T(1), nthreads); c2r(shape_out, stride_inter, stride_out, axes.back(), forward, tmp.data(), data_out, fct, nthreads); } template void r2r_fftpack(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct, size_t nthreads=1) { if (util::prod(shape)==0) return; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, 
axes); cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); general_nd>(ain, aout, axes, fct, nthreads, ExecR2R{real2hermitian, forward}); } template void r2r_separable_hartley(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const T *data_in, T *data_out, T fct, size_t nthreads=1) { if (util::prod(shape)==0) return; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); general_nd>(ain, aout, axes, fct, nthreads, ExecHartley{}, false); } template void r2r_genuine_hartley(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const T *data_in, T *data_out, T fct, size_t nthreads=1) { if (util::prod(shape)==0) return; if (axes.size()==1) return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in, data_out, fct, nthreads); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); shape_t tshp(shape); tshp[axes.back()] = tshp[axes.back()]/2+1; arr> tdata(util::prod(tshp)); stride_t tstride(shape.size()); tstride.back()=sizeof(std::complex); for (size_t i=tstride.size()-1; i>0; --i) tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]); r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads); cndarr> atmp(tdata.data(), tshp, tstride); ndarr aout(data_out, shape, stride_out); simple_iter iin(atmp); rev_iter iout(aout, axes); while(iin.remaining()>0) { auto v = atmp[iin.ofs()]; aout[iout.ofs()] = v.r+v.i; aout[iout.rev_ofs()] = v.r-v.i; iin.advance(); iout.advance(); } } } // namespace detail using detail::FORWARD; using detail::BACKWARD; using detail::shape_t; using detail::stride_t; using detail::c2c; using detail::c2r; using detail::r2c; using detail::r2r_fftpack; using detail::r2r_separable_hartley; using detail::r2r_genuine_hartley; using detail::dct; using detail::dst; } // namespace pocketfft #undef 
POCKETFFT_NOINLINE #undef POCKETFFT_RESTRICT #endif // POCKETFFT_HDRONLY_H ================================================ FILE: mlx/CMakeLists.txt ================================================ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) # Define MLX_VERSION only in the version.cpp file. add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp) target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}") target_include_directories(mlx_version PRIVATE ${PROJECT_SOURCE_DIR}) target_link_libraries(mlx PRIVATE $) # Do not export symbols by default. set_target_properties( mlx mlx_version PROPERTIES VISIBILITY_INLINES_HIDDEN ON CXX_VISIBILITY_PRESET hidden CUDA_VISIBILITY_PRESET hidden) # Define MLX_EXPORT for shared libraries, MLX_STATIC for static libraries. 
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL MLX_EXPORT)
if(BUILD_SHARED_LIBS)
  # mlx_version is an OBJECT library (see its add_library call), so it is
  # given MLX_EXPORT explicitly -- presumably because mlx's DEFINE_SYMBOL does
  # not apply to it.
  target_compile_definitions(mlx_version PUBLIC MLX_EXPORT)
else()
  # Static builds: MLX_STATIC makes MLX_API expand to nothing (see api.h).
  target_compile_definitions(mlx PUBLIC MLX_STATIC)
  target_compile_definitions(mlx_version PUBLIC MLX_STATIC)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
  # Suppress warnings: note: parameter passing for argument of type
  # 'std::pair' when C++17 is enabled changed to match C++14 in
  # GCC 10.1
  target_compile_options(mlx PRIVATE -Wno-psabi)
endif()

if(MSVC)
  # Some of CUDA's headers include windows.h, which defines min/max macros.
  target_compile_definitions(mlx PRIVATE NOMINMAX WIN32_LEAN_AND_MEAN)
  # Unicode support in fmt does not compile in .cu files.
  target_compile_definitions(mlx PRIVATE FMT_UNICODE=0)
  # Disable some MSVC warnings to speed up compilation.
  # NOTE(review): the conditions of the generator expressions below (and of
  # the /bigobj and /Zc:preprocessor options that follow) are missing in this
  # copy -- `$<$:...>` is not a valid generator expression. Upstream they
  # presumably select by compile language (host C++ vs. CUDA, given the
  # -Xcompiler forms). Restore them before use.
  target_compile_options(
    mlx
    PUBLIC $<$:/wd4244 /wd4267>
    PRIVATE $<$:/wd4068 /wd4146 /wd4700 /wd4804 /wd4805>
            $<$:-Xcompiler=/wd4244 -Xcompiler=/wd4267>)
  # Enable /bigobj for heavily templated code (e.g., binary.cpp) that exceeds
  # the default 65,535 section limit in COFF object files.
  target_compile_options(
    mlx PRIVATE $<$:/bigobj> $<$:-Xcompiler=/bigobj>)
  # Use modern preprocessor, otherwise CCCL would complain.
target_compile_options(
  mlx PRIVATE $<$:/Zc:preprocessor> $<$:-Xcompiler=/Zc:preprocessor>)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)

# backend/cpu when MLX_BUILD_CPU is set, otherwise backend/no_cpu.
if(MLX_BUILD_CPU)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu)
else()
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)

# When a GPU backend is disabled, the matching no_metal.cpp / no_cuda.cpp
# source is compiled into mlx instead of adding the backend directory.
if(MLX_BUILD_METAL)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
endif()

if(MLX_BUILD_CUDA)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif()

# backend/gpu is added when either Metal or CUDA is enabled, otherwise
# backend/no_gpu.
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
endif()
================================================
FILE: mlx/allocator.h
================================================
// Copyright © 2023 Apple Inc.

#pragma once

// NOTE(review): the header name of this bare #include is missing in this
// copy; restore it from upstream.
#include
#include "mlx/api.h"

namespace mlx::core::allocator {

// Simple wrapper around buffer pointers
// WARNING: Only Buffer objects constructed from and those that wrap
// raw pointers from mlx::allocator are supported.
class MLX_API Buffer {
 private:
  void* ptr_;

 public:
  explicit Buffer(void* ptr) : ptr_(ptr) {};

  // Get the raw data pointer from the buffer (declared here, defined
  // elsewhere).
  void* raw_ptr();

  // Get the buffer pointer from the buffer (the pointer passed to the
  // constructor).
  const void* ptr() const {
    return ptr_;
  };
  void* ptr() {
    return ptr_;
  };
};

class MLX_API Allocator {
  /** Abstract base class for a memory allocator.
   */
 public:
  virtual Buffer malloc(size_t size) = 0;
  virtual void free(Buffer buffer) = 0;
  virtual size_t size(Buffer buffer) const = 0;
  // Default: no-copy wrapping is not supported, signalled by a Buffer whose
  // ptr() is nullptr.
  virtual Buffer make_buffer(void* ptr, size_t size) {
    return Buffer{nullptr};
  };
  // Default: nothing to release.
  virtual void release(Buffer buffer) {}

  Allocator() = default;
  Allocator(const Allocator& other) = delete;
  Allocator(Allocator&& other) = delete;
  Allocator& operator=(const Allocator& other) = delete;
  Allocator& operator=(Allocator&& other) = delete;
  virtual ~Allocator() = default;
};

// The allocator that the free functions below forward to.
MLX_API Allocator& allocator();

inline Buffer malloc(size_t size) {
  return allocator().malloc(size);
}

inline void free(Buffer buffer) {
  allocator().free(buffer);
}

// Make a Buffer from a raw pointer of the given size without a copy. If a
// no-copy conversion is not possible then the returned buffer.ptr() will be
// nullptr. Any buffer created with this function must be released with
// release(buffer)
inline Buffer make_buffer(void* ptr, size_t size) {
  return allocator().make_buffer(ptr, size);
};

// Release a buffer from the allocator made with make_buffer
inline void release(Buffer buffer) {
  allocator().release(buffer);
}

} // namespace mlx::core::allocator
================================================
FILE: mlx/api.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

// MLX_API macro for controlling symbol visibility, must add for public APIs.
//
// Usage:
//   MLX_API void some_function(...);
//   class MLX_API SomeClass { ... };
//
// On Windows, MLX_EXPORT (set via DEFINE_SYMBOL while building mlx, see
// mlx/CMakeLists.txt) selects dllexport; consumers see dllimport.

#if defined(MLX_STATIC)
// Static library build - no import/export decorations needed
#define MLX_API
#else
// Shared library build.
#if defined(_WIN32)
#if defined(MLX_EXPORT)
#define MLX_API __declspec(dllexport)
#else
#define MLX_API __declspec(dllimport)
#endif // defined(MLX_EXPORT)
#else
#define MLX_API __attribute__((visibility("default")))
#endif // defined(_WIN32)
#endif // defined(MLX_STATIC)
================================================
FILE: mlx/array.cpp
================================================
// Copyright © 2023-2024 Apple Inc.

// NOTE(review): header names and template arguments (every <...> span) are
// missing in this copy; restore them from upstream before compiling.
#include
#include
#include "mlx/array.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"

namespace mlx::core {

// Scalar complex constructor: converts `val` and copies it in via init().
array::array(const std::complex& val, Dtype dtype /* = complex64 */)
    : array_desc_(std::make_shared(Shape{}, dtype)) {
  auto cval = static_cast(val);
  init(&cval);
}

// Graph-node constructor. When the primitive's stream is on the GPU, rejects
// float64 inputs and a float64 output with std::invalid_argument.
array::array(
    Shape shape,
    Dtype dtype,
    std::shared_ptr primitive,
    std::vector inputs)
    : array_desc_(
          std::make_shared(
              std::move(shape),
              dtype,
              std::move(primitive),
              std::move(inputs))) {
  if (has_primitive() && this->primitive().stream().device == Device::gpu) {
    for (auto& in : this->inputs()) {
      if (in.dtype() == float64) {
        throw std::invalid_argument("float64 is not supported on the GPU");
      }
    }
    if (this->dtype() == float64) {
      throw std::invalid_argument("float64 is not supported on the GPU");
    }
  }
}

// Build one output per entry of `shapes` (dtype taken from `dtypes` at the
// same index), all sharing `primitive` and `inputs`, then link each output to
// the others as siblings.
// NOTE(review): dtypes is indexed with the shapes index; a dtypes vector
// shorter than shapes is not detected here.
std::vector array::make_arrays(
    std::vector shapes,
    const std::vector& dtypes,
    const std::shared_ptr& primitive,
    const std::vector& inputs) {
  std::vector outputs;
  for (size_t i = 0; i < shapes.size(); ++i) {
    outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
  }
  // For each node in |outputs|, its siblings are the other nodes.
  for (size_t i = 0; i < outputs.size(); ++i) {
    auto siblings = outputs;
    siblings.erase(siblings.begin() + i);
    // NOTE(review): set_siblings takes a uint16_t position, so `i` is
    // narrowed here.
    outputs[i].set_siblings(std::move(siblings), i);
  }
  return outputs;
}

// Share `other`'s buffer, layout and offset without taking ownership: the
// installed deleter is a no-op.
// NOTE(review): the copy is created through the graph-node constructor, whose
// ArrayDesc starts as `unscheduled`, and set_data() does not change the
// status. Confirm that callers never eval() it or query is_available() on it.
array array::unsafe_weak_copy(const array& other) {
  auto cpy = array(other.shape(), other.dtype(), nullptr, {});
  cpy.set_data(
      other.buffer(),
      other.data_size(),
      other.strides(),
      other.flags(),
      [](auto) {});
  cpy.array_desc_->offset = other.array_desc_->offset;
  return cpy;
}

// 1-D float32 array from a list; an empty list gives an empty float32 array.
array::array(std::initializer_list data)
    : array_desc_(
          std::make_shared(
              Shape{static_cast(data.size())},
              float32)) {
  init(data.begin());
}

// 1-D array of the requested dtype from a list.
array::array(std::initializer_list data, Dtype dtype)
    : array_desc_(
          std::make_shared(
              Shape{static_cast(data.size())},
              dtype)) {
  init(data.begin());
}

array::array(
    void* data,
    Shape shape,
    Dtype dtype,
    const std::function& deleter)
    : array_desc_(std::make_shared(std::move(shape), dtype)) {
  auto buffer = allocator::make_buffer(data, nbytes());
  if (buffer.ptr() == nullptr) {
    // No zero-copy wrapping available: copy the caller's data into a fresh
    // allocation and hand the caller's memory back right away.
    set_data(allocator::malloc(nbytes()));
    auto ptr = static_cast(data);
    std::copy(ptr, ptr + nbytes(), this->data());
    deleter(data);
  } else {
    // Zero-copy: when the data is dropped, release the wrapped buffer from
    // the allocator, then call the user's deleter on buffer.raw_ptr().
    auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
      auto ptr = buffer.raw_ptr();
      allocator::release(buffer);
      return deleter(ptr);
    };
    set_data(buffer, std::move(wrapped_deleter));
  }
}

/* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
    : array_desc_(std::make_shared(std::move(shape), dtype)) {
  set_data(data, deleter);
}

// Drop the primitive of this array and of all its siblings, then clear the
// inputs, sibling lists and positions of the siblings and of this array.
void array::detach() {
  array_desc_->primitive = nullptr;
  for (auto& s : array_desc_->siblings) {
    s.array_desc_->primitive = nullptr;
  }
  for (auto& s : array_desc_->siblings) {
    s.array_desc_->inputs.clear();
    s.array_desc_->siblings.clear();
    s.array_desc_->position = 0;
  }
  array_desc_->inputs.clear();
  array_desc_->siblings.clear();
  array_desc_->position = 0;
}

// An `evaluated` array whose event is invalid or already signaled is
// promoted to `available` here (its event is detached), even though the
// method is const.
bool array::is_available() const {
  if (status() == Status::available) {
    return true;
  } else if (
      status() == Status::evaluated &&
      (!event().valid() || event().is_signaled())) {
    detach_event();
    set_status(Status::available);
    return true;
  }
  return false;
}

// Block on the attached event (if any) and mark the array available.
void array::wait() {
  if (!is_available()) {
    if (event().valid()) {
      event().wait();
      detach_event();
    }
    set_status(Status::available);
  }
}

void array::eval() {
  // Ensure the array is ready to be read
  if (status() == Status::unscheduled) {
    mlx::core::eval({*this});
  } else {
    wait();
  }
}

bool array::is_tracer() const {
  return (array_desc_->is_tracer && detail::in_tracing()) ||
      detail::retain_graph();
}

// Attach a freshly allocated buffer and reset the layout metadata. The
// array becomes (row-)contiguous with data_size == size(). It is
// col-contiguous only when at most one dimension exceeds 1. Strides are not
// modified.
void array::set_data(allocator::Buffer buffer, Deleter d) {
  array_desc_->data = std::make_shared(buffer, d);
  array_desc_->offset = 0;
  array_desc_->data_size = size();
  array_desc_->flags.contiguous = true;
  array_desc_->flags.row_contiguous = true;
  auto max_dim = std::max_element(shape().begin(), shape().end());
  array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
}

// Attach a buffer with caller-supplied layout metadata; the offset is reset
// to 0.
void array::set_data(
    allocator::Buffer buffer,
    size_t data_size,
    Strides strides,
    Flags flags,
    Deleter d) {
  array_desc_->data = std::make_shared(buffer, d);
  array_desc_->offset = 0;
  array_desc_->data_size = data_size;
  array_desc_->strides = std::move(strides);
  array_desc_->flags = flags;
}

void array::copy_shared_buffer(
    const array& other,
    const Strides& strides,
    Flags flags,
    size_t data_size,
    int64_t offset /* = 0 */) {
  array_desc_->data = other.array_desc_->data;
  array_desc_->strides = strides;
  array_desc_->flags = flags;
  array_desc_->data_size = data_size;
  // `offset` is in elements of this array's dtype; the stored offset is in
  // bytes and is relative to `other`'s own byte offset.
  array_desc_->offset =
      sizeof(char) * itemsize() * offset + other.array_desc_->offset;
}

void array::copy_shared_buffer(const array& other) {
  copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}

array::~array() {
  if (array_desc_ == nullptr) {
    return;
  }

  // Detached/detaching
  if (array_desc_->primitive == nullptr) {
    return;
  }

  // Break circular reference for non-detached arrays with siblings
  if (auto n = siblings().size(); n > 0) {
    bool do_detach = true;
    // If all siblings have siblings.size() references except
    // the one we are currently destroying (which has siblings.size() + 1)
    // then there are no more external references
    do_detach &= (array_desc_.use_count() == (n + 1));
    for (auto& s : siblings()) {
      do_detach &= (s.array_desc_.use_count() == n);
      if (!do_detach) {
        break;
      }
    }
    if (do_detach) {
      for (auto& s : siblings()) {
        for (auto& ss : s.siblings()) {
          // Set to null here to avoid descending into array destructor
          // for siblings
          ss.array_desc_ = nullptr;
        }
        s.array_desc_->siblings.clear();
      }
    }
  }
}

// Compute row-major strides and the element count from `shape`, and mark the
// descriptor as a tracer if any input is one.
void array::ArrayDesc::init() {
  strides.resize(shape.size());
  size = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    strides[i] = size;
    size *= shape[i];
  }
  for (const auto& in : inputs) {
    is_tracer |= in.is_tracer();
  }
}

// Leaf descriptor: starts `available`.
array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
    : shape(std::move(shape)), dtype(dtype), status(Status::available) {
  init();
}

// Graph-node descriptor: starts `unscheduled`.
array::ArrayDesc::ArrayDesc(
    Shape shape,
    Dtype dtype,
    std::shared_ptr primitive,
    std::vector inputs)
    : shape(std::move(shape)),
      dtype(dtype),
      primitive(std::move(primitive)),
      status(Status::unscheduled),
      inputs(std::move(inputs)) {
  init();
}

array::ArrayDesc::~ArrayDesc() {
  // When an array description is destroyed it will delete a bunch of arrays
  // that may also destroy their corresponding descriptions and so on and so
  // forth.
  //
  // This calls recursively the destructor and can result in stack overflow, we
  // instead put them in a vector and destroy them one at a time resulting in a
  // max stack depth of 2.
  if (inputs.empty()) {
    return;
  }

  std::vector> for_deletion;

  // Move out the descriptors of `ad`'s inputs (and their siblings) that no
  // one else references, so they can be destroyed iteratively below.
  auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
    std::unordered_map input_map;
    for (array& a : ad.inputs) {
      if (a.array_desc_) {
        input_map.insert({a.id(), a});
        for (auto& s : a.siblings()) {
          input_map.insert({s.id(), s});
        }
      }
    }
    ad.inputs.clear();
    for (auto& [_, a] : input_map) {
      bool is_deletable =
          (a.array_desc_.use_count() <= a.siblings().size() + 1);
      // An array with siblings is deletable only if all of its siblings
      // are deletable
      for (auto& s : a.siblings()) {
        if (!is_deletable) {
          break;
        }
        int is_input = (input_map.find(s.id()) != input_map.end());
        is_deletable &=
            s.array_desc_.use_count() <= a.siblings().size() + is_input;
      }
      if (is_deletable) {
        for_deletion.push_back(std::move(a.array_desc_));
      }
    }
  };

  append_deletable_inputs(*this);

  while (!for_deletion.empty()) {
    // top is going to be deleted at the end of the block *after* the arrays
    // with inputs have been moved into the vector
    auto top = std::move(for_deletion.back());
    for_deletion.pop_back();
    append_deletable_inputs(*top);

    // Clear out possible siblings to break circular references
    for (auto& s : top->siblings) {
      // Set to null here to avoid descending into top-level
      // array destructor for siblings
      s.array_desc_ = nullptr;
    }
    top->siblings.clear();
  }
}

// Throws std::invalid_argument for 0-d arrays, which have no first axis.
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
    : arr(arr), idx(idx) {
  if (arr.ndim() == 0) {
    throw std::invalid_argument("Cannot iterate over 0-d array.");
  }
}

// Slice index `idx` along axis 0 and drop that axis from the result.
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
  auto start = Shape(arr.ndim(), 0);
  auto end = arr.shape();
  auto shape = arr.shape();
  shape.erase(shape.begin());
  start[0] = idx;
  end[0] = idx + 1;
  return reshape(slice(arr, start, end), shape);
};

} // namespace mlx::core
================================================
FILE: mlx/array.h
================================================
// Copyright © 2023 Apple Inc.
#pragma once

// NOTE(review): in this copy every <...> span has been lost. The five bare
// #include directives below have no header name, and template parameter
// lists / template arguments are missing throughout (e.g. `template` with no
// parameter list, `std::shared_ptr primitive`). Restore them from upstream
// before compiling; the code tokens below are kept exactly as found.
#include
#include
#include
#include
#include
#include "mlx/allocator.h"
#include "mlx/api.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
#include "mlx/small_vector.h"

namespace mlx::core {

// Forward declaration
class Primitive;

// Callback that releases an allocator::Buffer; array::Data's destructor calls
// it on the buffer it holds.
using Deleter = std::function;
using ShapeElem = int32_t;
using Shape = SmallVector;
using Strides = SmallVector;

class MLX_API array {
  /* An array is really a node in a graph. It contains a shared ArrayDesc
   * object */

 public:
  /** Construct a scalar array with zero dimensions. */
  template
  explicit array(T val, Dtype dtype = TypeToDtype());

  /* Special case since std::complex can't be implicitly converted to other
   * types. */
  explicit array(const std::complex& val, Dtype dtype = complex64);

  /* Construct an array with the given shape, copying size() elements starting
   * at the iterator `data` (see init()). */
  template
  explicit array(
      It data,
      Shape shape,
      Dtype dtype = TypeToDtype::value_type>());

  /* Construct a 1-D array holding the elements of `data`. */
  template
  explicit array(std::initializer_list data, Dtype dtype = TypeToDtype());

  /* Special case so empty lists default to float32. */
  explicit array(std::initializer_list data);

  /* Special case so array({}, type) is an empty array. */
  explicit array(std::initializer_list data, Dtype dtype);

  /* Construct an array with the given shape from the elements of `data`.
   * Throws std::invalid_argument if data.size() differs from the number of
   * elements implied by `shape`. */
  template
  explicit array(
      std::initializer_list data,
      Shape shape,
      Dtype dtype = TypeToDtype());

  /* Build an array from a raw pointer. The constructor will attempt to use the
   * input data without a copy. The deleter will be called when the array no
   * longer needs the underlying memory - after the array is destroyed in the
   * no-copy case and after the copy otherwise. */
  explicit array(
      void* data,
      Shape shape,
      Dtype dtype,
      const std::function& deleter);

  /* Build an array from a buffer. `deleter` is invoked on the buffer when the
   * last reference to the array's data is dropped. */
  explicit array(
      allocator::Buffer data,
      Shape shape,
      Dtype dtype,
      Deleter deleter = allocator::free);

  /** Assignment to rvalue does not compile. */
  array& operator=(const array& other) && = delete;
  array& operator=(array&& other) && = delete;

  /** Default copy and move constructors otherwise. */
  array& operator=(array&& other) & = default;
  array(const array& other) = default;
  array(array&& other) = default;

  // Copy assignment shares `other`'s descriptor. It is a no-op when both
  // already refer to the same descriptor (same id()).
  array& operator=(const array& other) & {
    if (this->id() != other.id()) {
      this->array_desc_ = other.array_desc_;
    }
    return *this;
  }

  /** The size of the array's datatype in bytes. */
  size_t itemsize() const {
    return size_of(dtype());
  }

  /** The number of elements in the array. */
  size_t size() const {
    return array_desc_->size;
  }

  /** The number of bytes in the array. */
  size_t nbytes() const {
    return size() * itemsize();
  }

  /** The number of dimensions of the array. */
  size_t ndim() const {
    return array_desc_->shape.size();
  }

  /** The shape of the array as a vector of integers. */
  const Shape& shape() const {
    return array_desc_->shape;
  }

  /**
   * Get the size of the corresponding dimension.
   *
   * This function supports negative indexing and provides
   * bounds checking.
   */
  auto shape(int dim) const {
    return shape().at(dim < 0 ? dim + static_cast(ndim()) : dim);
  }

  /** The strides of the array. */
  const Strides& strides() const {
    return array_desc_->strides;
  }

  /**
   * Get the stride of the corresponding dimension.
   *
   * This function supports negative indexing and provides
   * bounds checking.
   */
  auto strides(int dim) const {
    return strides().at(dim < 0 ? dim + static_cast(ndim()) : dim);
  }

  /** Get the array's data type. */
  Dtype dtype() const {
    return array_desc_->dtype;
  }

  /** Evaluate the array: runs the graph if the array is still unscheduled,
   * otherwise waits until it is available. */
  void eval();

  /** Get the value from a scalar array. Throws std::invalid_argument unless
   * size() == 1. The non-const overload evaluates the array first. The const
   * overload throws if the array is still unscheduled. */
  template
  T item();
  template
  T item() const;

  // Iterates over the array along its first axis. Dereferencing yields the
  // idx-th slice with the leading axis removed (see operator* in array.cpp).
  // The iterator stores a reference to the array, so it must not outlive it.
  // NOTE(review): difference_type is unsigned, although random-access
  // iterators require a signed difference_type. operator+ also advances this
  // iterator in place rather than returning a new one.
  struct MLX_API ArrayIterator {
    using iterator_category = std::random_access_iterator_tag;
    using difference_type = size_t;
    using value_type = const array;
    using reference = value_type;

    // Throws std::invalid_argument for 0-d arrays.
    explicit ArrayIterator(const array& arr, int idx = 0);

    reference operator*() const;

    ArrayIterator& operator+(difference_type diff) {
      idx += diff;
      return *this;
    }

    ArrayIterator& operator++() {
      idx++;
      return *this;
    }

    // Iterators are equal when they refer to the same array node and index.
    friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
      return a.arr.id() == b.arr.id() && a.idx == b.idx;
    }
    friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
      return !(a == b);
    }

   private:
    const array& arr;
    int idx;
  };

  ArrayIterator begin() const {
    return ArrayIterator(*this);
  }
  // One past the last slice along the first axis.
  ArrayIterator end() const {
    return ArrayIterator(*this, shape(0));
  }

  /**
   * The following methods should be used with caution.
   * They are intended for use by the backend implementation and the
   * API may change.
   */

  // Construct an unscheduled graph node computed by `primitive` from
  // `inputs`. Throws std::invalid_argument when the primitive's stream is on
  // the GPU and the output or any input is float64.
  array(
      Shape shape,
      Dtype dtype,
      std::shared_ptr primitive,
      std::vector inputs);

  // Create one output per entry of `shapes` (paired by index with `dtypes`),
  // all sharing `primitive` and `inputs`. Each output's siblings are the
  // other outputs, and its sibling position is its index.
  static std::vector make_arrays(
      std::vector shapes,
      const std::vector& dtypes,
      const std::shared_ptr& primitive,
      const std::vector& inputs);

  /**
   * Get a new array that refers to the same data as the input but with a
   * non-owning pointer to it. Note the array is detached from the graph and has
   * no inputs, siblings or primitive.
   *
   * The copy's data is installed with a no-op deleter, so the buffer must be
   * kept alive by `other` (or another owner) for as long as the copy is used.
   */
  static array unsafe_weak_copy(const array& other);

  /** A unique identifier for an array (the address of its descriptor). */
  std::uintptr_t id() const {
    return reinterpret_cast(array_desc_.get());
  }

  /** A unique identifier for an array's primitive (the address of the
   * primitive object, or a null pointer value when there is none). */
  std::uintptr_t primitive_id() const {
    return reinterpret_cast(array_desc_->primitive.get());
  }

  // Holds a buffer together with the Deleter that releases it; the deleter
  // is called on the buffer when Data is destroyed.
  struct Data {
    allocator::Buffer buffer;
    Deleter d;
    Data(allocator::Buffer buffer, Deleter d = allocator::free)
        : buffer(buffer), d(d) {}
    // Not copyable
    Data(const Data& d) = delete;
    Data& operator=(const Data& d) = delete;
    // Moving leaves `o` with a null buffer and a no-op deleter, so only the
    // destination releases the buffer.
    // NOTE(review): consider marking this noexcept and moving (rather than
    // copying) o.d.
    Data(Data&& o) : buffer(o.buffer), d(o.d) {
      o.buffer = allocator::Buffer(nullptr);
      o.d = [](allocator::Buffer) {};
    }
    ~Data() {
      d(buffer);
    }
  };

  struct Flags {
    // True iff there are no gaps in the underlying data. Each item
    // in the underlying data buffer belongs to at least one index.
    //
    // True iff:
    // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
    bool contiguous : 1;

    // True iff:
    // strides[-1] == 1 and
    // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
    // range(ndim - 1))
    bool row_contiguous : 1;

    // True iff:
    // strides[0] == 1 and
    // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
    // range(1, ndim))
    bool col_contiguous : 1;
  };

  /** The array's primitive. Requires has_primitive(). */
  Primitive& primitive() const {
    return *(array_desc_->primitive);
  }

  /** A shared pointer to the array's primitive. */
  std::shared_ptr& primitive_ptr() const {
    return array_desc_->primitive;
  }

  /** Check if the array has an attached primitive. This is false for leaf
   * nodes, including arrays that have been detached from the graph. */
  bool has_primitive() const {
    return array_desc_->primitive != nullptr;
  }

  /** The array's inputs. */
  const std::vector& inputs() const {
    return array_desc_->inputs;
  }

  std::vector& inputs() {
    return array_desc_->inputs;
  }

  /** True indicates the array's buffer is safe to reuse: this array holds
   * the only reference to both its descriptor and its data. */
  bool is_donatable() const {
    return array_desc_.use_count() == 1 &&
        (array_desc_->data.use_count() == 1);
  }

  /** The array's siblings. */
  const std::vector& siblings() const {
    return array_desc_->siblings;
  }

  /** The array's siblings. */
  std::vector& siblings() {
    return array_desc_->siblings;
  }

  /** The array's position in the sibling list. */
  int sibling_position() const {
    return array_desc_->position;
  }

  // Replace the sibling list and this array's position among the outputs.
  // NOTE(review): `position` is uint16_t, while ArrayDesc::position is
  // uint32_t and make_arrays passes a size_t index. Positions above 65535
  // would be truncated; consider widening the parameter.
  void set_siblings(std::vector siblings, uint16_t position) {
    array_desc_->siblings = std::move(siblings);
    array_desc_->position = position;
  }

  /** The outputs of the array's primitive (i.e. this array and
   * its siblings) in the order the primitive expects. */
  std::vector outputs() const {
    auto idx = array_desc_->position;
    std::vector outputs;
    outputs.reserve(siblings().size() + 1);
    // Siblings before this array's position, then this array, then the rest.
    outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
    outputs.push_back(*this);
    outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
    return outputs;
  }

  /** Detach the array from the graph: drops the primitive, inputs and
   * siblings of this array and of each of its siblings. */
  void detach();

  /** Get the Flags bit-field. */
  const Flags& flags() const {
    return array_desc_->flags;
  }

  /** The size (in elements) of the underlying buffer the array points to.
   *
   * This can be different than the actual size of the array if the array has
   * been broadcast or irregularly strided. If ``first`` is the offset into
   * the data buffer of the first element of the array (i.e. the offset
   * corresponding to ``arr[0, 0, ...]``) and last is the offset into the
   * data buffer of the last element of the array (i.e. the offset
   * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
   * Note, ``data_size`` is in units of ``itemsize()`` (not bytes).
   *
   * NOTE(review): for a freshly allocated array, set_data() stores
   * data_size == size(), while first == 0 and last == size() - 1. The
   * relation above therefore looks off by one (``last - first + 1``);
   * confirm.
   **/
  size_t data_size() const {
    return array_desc_->data_size;
  }

  // Accessors for the underlying buffer. These dereference ArrayDesc::data,
  // so they are only valid once data has been attached (set_data /
  // copy_shared_buffer).
  allocator::Buffer& buffer() {
    return array_desc_->data->buffer;
  }
  const allocator::Buffer& buffer() const {
    return array_desc_->data->buffer;
  }

  // Size of the underlying buffer as reported by the allocator.
  size_t buffer_size() const {
    return allocator::allocator().size(buffer());
  }

  // Return the shared pointer to the array::Data struct
  const std::shared_ptr& data_shared_ptr() const {
    return array_desc_->data;
  }

  // Return a raw pointer to the array's data. This function may do a copy if
  // the underlying buffer is not accessible on the CPU. When accessing the
  // data for GPU kernels, be sure to use the correct method / function for the
  // given backend to access the GPU pointer.
  // The pointer is buffer().raw_ptr() advanced by offset(), which is stored
  // in bytes (see copy_shared_buffer).
  template
  T* data() {
    return reinterpret_cast(
        (static_cast(buffer().raw_ptr()) + array_desc_->offset));
  }

  template
  const T* data() const {
    return const_cast(*this).data();
  }

  // Offset of this array's first element from the start of the buffer, in
  // bytes.
  int64_t offset() const {
    return array_desc_->offset;
  }

  enum Status {
    // The output of a computation which has not been scheduled.
    // For example, the status of `x` in `auto x = a + b`.
    unscheduled,

    // The array's `eval_*` function has been run, but the computation is not
    // necessarily complete. The array will have memory allocated and if it is
    // not a tracer then it will be detached from the graph.
    evaluated,

    // If the array is the output of a computation then the computation
    // is complete. Constant arrays are always available (e.g. `array({1, 2,
    // 3})`)
    available
  };

  // Check if the array is safe to read. As a side effect, an `evaluated`
  // array whose event is invalid or signaled is promoted to `available`.
  bool is_available() const;

  // Wait on the array to be available. After this `is_available` returns
  // `true`.
  void wait();

  Status status() const {
    return array_desc_->status;
  }

  void set_status(Status s) const {
    array_desc_->status = s;
  }

  // Get the array's shared event
  Event& event() const {
    return array_desc_->event;
  }

  // Attach an event to a not yet evaluated array
  void attach_event(Event e) const {
    array_desc_->event = std::move(e);
  }

  // Replace the attached event with a default-constructed Event.
  void detach_event() const {
    array_desc_->event = Event{};
  }

  // Mark the array as a tracer array (true) or not.
  void set_tracer(bool is_tracer) {
    array_desc_->is_tracer = is_tracer;
  }

  // Check if the array is a tracer array. Returns
  // `(marked as tracer && detail::in_tracing()) || detail::retain_graph()`.
  bool is_tracer() const;

  // Attach a freshly allocated buffer. Resets the offset to 0, sets
  // data_size to size() and marks the array contiguous and row-contiguous
  // (col-contiguous only when at most one dimension exceeds 1).
  void set_data(allocator::Buffer buffer, Deleter d = allocator::free);

  // Attach a buffer with explicit layout metadata; the offset is reset to 0.
  void set_data(
      allocator::Buffer buffer,
      size_t data_size,
      Strides strides,
      Flags flags,
      Deleter d = allocator::free);

  // Share `other`'s data buffer with the given layout. `offset` is in
  // elements of this array's dtype and is added to `other`'s own offset.
  void copy_shared_buffer(
      const array& other,
      const Strides& strides,
      Flags flags,
      size_t data_size,
      int64_t offset = 0);

  // Share `other`'s data buffer, strides, flags, data_size and offset.
  void copy_shared_buffer(const array& other);

  // Make this array refer to `other`'s descriptor (both become the same
  // graph node).
  void overwrite_descriptor(const array& other) {
    array_desc_ = other.array_desc_;
  }

  // Breaks the reference cycle between sibling outputs once no external
  // references to any of them remain.
  ~array();

 private:
  // Initialize the array's data: allocates size() * size_of(dtype()) bytes
  // and copies size() elements from `src`.
  template
  void init(const It src);

  struct MLX_API ArrayDesc {
    Shape shape;
    Strides strides;
    // Number of elements (product of `shape`), computed by init().
    size_t size;
    Dtype dtype;
    std::shared_ptr primitive;

    Status status;

    // An event on the array used for synchronization
    Event event;

    // Indicates an array is being used in a graph transform
    // and should not be detached from the graph
    bool is_tracer{false};

    // This is a shared pointer so that *different* arrays
    // can share the underlying data buffer.
    std::shared_ptr data;

    // Offset from beginning of data pointer, in bytes
    int64_t offset{0};

    // The size in elements of the data buffer the array accesses
    size_t data_size{0};

    // Contains useful meta data about the array
    Flags flags{true, true, true};

    std::vector inputs;
    // An array to keep track of the siblings from a multi-output
    // primitive.
    std::vector siblings;
    // The array's position in the output list
    uint32_t position{0};

    // Leaf descriptor: status starts as `available`.
    explicit ArrayDesc(Shape shape, Dtype dtype);

    // Graph-node descriptor: status starts as `unscheduled`.
    explicit ArrayDesc(
        Shape shape,
        Dtype dtype,
        std::shared_ptr primitive,
        std::vector inputs);

    // Releases inputs iteratively rather than recursively to bound stack
    // depth on long graphs.
    ~ArrayDesc();

   private:
    // Initialize size, strides, and other metadata
    void init();
  };

  // The ArrayDesc contains the details of the materialized array including the
  // shape, strides, the data type. It also includes
  // the primitive which knows how to compute the array's data from its inputs
  // and the list of array's inputs for the primitive.
  std::shared_ptr array_desc_;
};

template
array::array(T val, Dtype dtype /* = TypeToDtype() */)
    : array_desc_(std::make_shared(Shape{}, dtype)) {
  init(&val);
}

template
array::array(
    It data,
    Shape shape,
    Dtype dtype /* = TypeToDtype::value_type>() */)
    : array_desc_(std::make_shared(std::move(shape), dtype)) {
  init(data);
}

template
array::array(
    std::initializer_list data,
    Dtype dtype /* = TypeToDtype() */)
    : array_desc_(
          std::make_shared(
              Shape{static_cast(data.size())},
              dtype)) {
  init(data.begin());
}

template
array::array(
    std::initializer_list data,
    Shape shape,
    Dtype dtype /* = TypeToDtype() */)
    : array_desc_(std::make_shared(std::move(shape), dtype)) {
  if (data.size() != size()) {
    throw std::invalid_argument(
        "Data size and provided shape mismatch in array construction.");
  }
  init(data.begin());
}

template
T array::item() {
  if (size() != 1) {
    throw std::invalid_argument("item can only be called on arrays of size 1.");
  }
  eval();
  return *data();
}

template
T array::item() const {
  if (size() != 1) {
    throw std::invalid_argument("item can only be called on arrays of size 1.");
  }
  // Refuse to schedule evaluation from a const context. For an array that is
  // already scheduled, eval() only waits for it.
  if (status() == Status::unscheduled) {
    throw std::invalid_argument(
        "item() const can only be called on evaled arrays");
  }
  const_cast(this)->eval();
  return *data();
}

// Allocate storage for size() elements and copy them from `src`, dispatching
// on dtype().
// NOTE(review): the per-case template arguments of data() are missing in this
// copy. Upstream, each case presumably writes through data<T>() for the C++
// type of that dtype.
template
void array::init(It src) {
  set_data(allocator::malloc(size() * size_of(dtype())));
  switch (dtype()) {
    case bool_:
      std::copy(src, src + size(), data());
      break;
    case uint8:
      std::copy(src, src + size(), data());
      break;
    case uint16:
      std::copy(src, src + size(), data());
      break;
    case uint32:
      std::copy(src, src + size(), data());
      break;
    case uint64:
      std::copy(src, src + size(), data());
      break;
    case int8:
      std::copy(src, src + size(), data());
      break;
    case int16:
      std::copy(src, src + size(), data());
      break;
    case int32:
      std::copy(src, src + size(), data());
      break;
    case int64:
      std::copy(src, src + size(), data());
      break;
    case float16:
      std::copy(src, src + size(), data());
      break;
    case float32:
      std::copy(src, src + size(), data());
      break;
    case float64:
      std::copy(src, src + size(), data());
      break;
    case bfloat16:
      std::copy(src, src + size(), data());
      break;
    case complex64:
      std::copy(src, src + size(), data());
      break;
  }
}

/* Utilities for determining whether a template parameter is array. */
template
inline constexpr bool is_array_v =
    std::is_same_v>, array>;

template
inline constexpr bool is_arrays_v = (is_array_v && ...);

template
using enable_for_arrays_t = typename std::enable_if_t>;

} // namespace mlx::core
================================================
FILE: mlx/backend/common/CMakeLists.txt
================================================
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
================================================
FILE: mlx/backend/common/binary.h
================================================
// Copyright © 2023 Apple Inc.
#pragma once #include "mlx/allocator.h" #include "mlx/array.h" #include "mlx/backend/common/utils.h" namespace mlx::core { enum class BinaryOpType { ScalarScalar, ScalarVector, VectorScalar, VectorVector, General, }; inline BinaryOpType get_binary_op_type(const array& a, const array& b) { BinaryOpType bopt; if (a.data_size() == 1 && b.data_size() == 1) { bopt = BinaryOpType::ScalarScalar; } else if (a.data_size() == 1 && b.flags().contiguous) { bopt = BinaryOpType::ScalarVector; } else if (b.data_size() == 1 && a.flags().contiguous) { bopt = BinaryOpType::VectorScalar; } else if ( (a.flags().row_contiguous && b.flags().row_contiguous) || (a.flags().col_contiguous && b.flags().col_contiguous)) { bopt = BinaryOpType::VectorVector; } else { bopt = BinaryOpType::General; } return bopt; } inline void set_binary_op_output_data( const array& a, const array& b, array& out, BinaryOpType bopt, std::function mallocfn = allocator::malloc) { bool b_donatable = is_donatable(b, out); bool a_donatable = is_donatable(a, out); switch (bopt) { case BinaryOpType::ScalarScalar: out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags()); break; case BinaryOpType::ScalarVector: if (b_donatable) { out.copy_shared_buffer(b); } else { out.set_data( mallocfn(b.data_size() * out.itemsize()), b.data_size(), b.strides(), b.flags()); } break; case BinaryOpType::VectorScalar: if (a_donatable) { out.copy_shared_buffer(a); } else { out.set_data( mallocfn(a.data_size() * out.itemsize()), a.data_size(), a.strides(), a.flags()); } break; case BinaryOpType::VectorVector: if (a_donatable) { out.copy_shared_buffer(a); } else if (b_donatable) { out.copy_shared_buffer(b); } else { out.set_data( mallocfn(a.data_size() * out.itemsize()), a.data_size(), a.strides(), a.flags()); } break; case BinaryOpType::General: if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) { out.copy_shared_buffer(a); } else if ( b_donatable && b.flags().row_contiguous && b.size() == out.size()) { 
out.copy_shared_buffer(b); } else { out.set_data(mallocfn(out.nbytes())); } break; } } } // namespace mlx::core ================================================ FILE: mlx/backend/common/broadcasting.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/backend/common/utils.h" namespace mlx::core { void broadcast(const array& in, array& out) { if (out.size() == 0) { out.set_data(allocator::malloc(0)); return; } Strides strides(out.ndim(), 0); int diff = out.ndim() - in.ndim(); for (int i = in.ndim() - 1; i >= 0; --i) { strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i]; } auto flags = in.flags(); if (out.size() > in.size()) { flags.row_contiguous = flags.col_contiguous = false; } out.copy_shared_buffer(in, strides, flags, in.data_size()); } } // namespace mlx::core ================================================ FILE: mlx/backend/common/broadcasting.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include "mlx/array.h" namespace mlx::core { void broadcast(const array& in, array& out); } // namespace mlx::core ================================================ FILE: mlx/backend/common/buffer_cache.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include #include #include #include namespace mlx::core { template class BufferCache { public: BufferCache( size_t page_size, std::function get_size, std::function free) : page_size_(page_size), get_size_(std::move(get_size)), free_(std::move(free)) {} ~BufferCache() { clear(); } BufferCache(const BufferCache&) = delete; BufferCache& operator=(const BufferCache&) = delete; T* reuse_from_cache(size_t size) { // Find the closest buffer in pool. auto it = buffer_pool_.lower_bound(size); if (it == buffer_pool_.end() || it->first >= std::min(2 * size, size + 2 * page_size_)) { return nullptr; } // Collect from the cache. 
T* buf = it->second->buf; pool_size_ -= it->first; // Remove from record. remove_from_list(it->second); buffer_pool_.erase(it); return buf; } void recycle_to_cache(T* buf) { assert(buf); // Add to cache. BufferHolder* bh = new BufferHolder(buf); add_at_head(bh); size_t size = get_size_(buf); pool_size_ += size; buffer_pool_.emplace(size, bh); } int release_cached_buffers(size_t min_bytes_to_free) { if (min_bytes_to_free >= 0.9 * pool_size_) { return clear(); } else { int n_release = 0; size_t total_bytes_freed = 0; while (tail_ && (total_bytes_freed < min_bytes_to_free)) { // Release buffer. size_t size = get_size_(tail_->buf); total_bytes_freed += size; free_(tail_->buf); n_release++; // Remove from record. auto its = buffer_pool_.equal_range(size); auto it = std::find_if(its.first, its.second, [this](const auto& el) { return el.second == tail_; }); assert(it != buffer_pool_.end()); buffer_pool_.erase(it); remove_from_list(tail_); } pool_size_ -= total_bytes_freed; return n_release; } } int clear() { int n_release = 0; for (auto& [size, holder] : buffer_pool_) { free_(holder->buf); n_release++; delete holder; } buffer_pool_.clear(); pool_size_ = 0; head_ = nullptr; tail_ = nullptr; return n_release; } size_t cache_size() const { return pool_size_; } size_t page_size() const { return page_size_; } private: struct BufferHolder { public: explicit BufferHolder(T* buf_) : buf(buf_) {} BufferHolder* prev{nullptr}; BufferHolder* next{nullptr}; T* buf; }; void add_at_head(BufferHolder* to_add) { if (!head_) { head_ = to_add; tail_ = to_add; } else { head_->prev = to_add; to_add->next = head_; head_ = to_add; } } void remove_from_list(BufferHolder* to_remove) { if (to_remove->prev && to_remove->next) { // if middle to_remove->prev->next = to_remove->next; to_remove->next->prev = to_remove->prev; } else if (to_remove->prev && to_remove == tail_) { // if tail tail_ = to_remove->prev; tail_->next = nullptr; } else if (to_remove == head_ && to_remove->next) { // if head head_ 
= to_remove->next; head_->prev = nullptr; } else if (to_remove == head_ && to_remove == tail_) { // if only element head_ = nullptr; tail_ = nullptr; } delete to_remove; } std::multimap buffer_pool_; BufferHolder* head_{nullptr}; BufferHolder* tail_{nullptr}; size_t pool_size_{0}; const size_t page_size_; std::function get_size_; std::function free_; }; } // namespace mlx::core ================================================ FILE: mlx/backend/common/common.cpp ================================================ // Copyright © 2024 Apple Inc. #include #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/utils.h" #include "mlx/primitives.h" namespace mlx::core { void AsStrided::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (!in.flags().row_contiguous) { // Just ensuring that inputs[0] came from the ops which would ensure the // input is row contiguous. throw std::runtime_error( "AsStrided must be used with row contiguous arrays only."); } // Compute the flags given the shape and strides bool row_contiguous = true, col_contiguous = true; size_t r = 1, c = 1; for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) { row_contiguous &= (r == strides_[i]) || (shape_[i] == 1); col_contiguous &= (c == strides_[j]) || (shape_[j] == 1); r *= shape_[i]; c *= shape_[j]; } auto flags = in.flags(); // TODO: Compute the contiguous flag in a better way cause now we are // unnecessarily strict. flags.contiguous = row_contiguous || col_contiguous; flags.row_contiguous = row_contiguous; flags.col_contiguous = col_contiguous; // There is no easy way to compute the actual data size so we use out.size(). // The contiguous flag will almost certainly not be set so no code should // rely on data_size anyway. 
size_t data_size = out.size(); return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); } void Broadcast::eval(const std::vector& inputs, array& out) { broadcast(inputs[0], out); } void BroadcastAxes::eval(const std::vector& inputs, array& out) { broadcast(inputs[0], out); } void Copy::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); out.copy_shared_buffer(inputs[0]); } void CustomTransforms::eval( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() > outputs.size()); for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size(); i++, j++) { outputs[i].copy_shared_buffer(inputs[j]); } } void Depends::eval( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() > outputs.size()); for (int i = 0; i < outputs.size(); i++) { outputs[i].copy_shared_buffer(inputs[i]); } } void ExpandDims::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; auto strides = in.strides(); for (auto ax : axes_) { strides.insert(strides.begin() + ax, 1); } out.copy_shared_buffer(in, strides, in.flags(), in.data_size()); } void NumberOfElements::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); out.set_data(allocator::malloc(out.nbytes())); double numel = 1; for (auto ax : axes_) { numel *= inputs[0].shape(ax); } if (inverted_) { numel = 1.0 / numel; } switch (out.dtype()) { case bool_: *out.data() = static_cast(numel); break; case uint8: *out.data() = static_cast(numel); break; case uint16: *out.data() = static_cast(numel); break; case uint32: *out.data() = static_cast(numel); break; case uint64: *out.data() = static_cast(numel); break; case int8: *out.data() = static_cast(numel); break; case int16: *out.data() = static_cast(numel); break; case int32: *out.data() = static_cast(numel); break; case int64: *out.data() = static_cast(numel); break; case float16: *out.data() = static_cast(numel); break; case float32: *out.data() = 
static_cast(numel); break; case bfloat16: *out.data() = static_cast(numel); break; case float64: *out.data() = static_cast(numel); break; case complex64: *out.data() = static_cast(numel); break; } } std::pair prepare_reshape(const array& in, const array& out) { // Special case for empty arrays or row contiguous arrays if (in.size() == 0 || in.flags().row_contiguous) { return {false, out.strides()}; } // Special case for scalars if (in.ndim() == 0) { return {false, Strides(out.ndim(), 0)}; } // Firstly let's collapse all the contiguous dimensions of the input auto [shape, strides] = collapse_contiguous_dims(in); // If shapes fit exactly in the contiguous dims then no copy is necessary so // let's check. Strides out_strides; bool copy_necessary = false; int j = 0; for (int i = 0; i < out.ndim(); i++) { int N = out.shape(i); if (j < shape.size() && shape[j] % N == 0) { shape[j] /= N; out_strides.push_back(shape[j] * strides[j]); j += (shape[j] == 1); } else if (N == 1) { // i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0 out_strides.push_back(out_strides.back()); } else { copy_necessary = true; break; } } return {copy_necessary, out_strides}; } void shared_buffer_reshape( const array& in, const Strides& out_strides, array& out) { auto flags = in.flags(); if (flags.row_contiguous) { // For row contiguous reshapes: // - Shallow copy the buffer // - If reshaping into a vector (all singleton dimensions except one) it // becomes col contiguous again. 
auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; } out.copy_shared_buffer(in, out_strides, flags, in.data_size()); } void Split::eval( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() == 1); auto& in = inputs[0]; auto compute_new_flags = [](const auto& shape, const auto& strides, size_t in_data_size, auto flags) { size_t data_size = 1; size_t f_stride = 1; size_t b_stride = 1; flags.row_contiguous = true; flags.col_contiguous = true; for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) { flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1; flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1; f_stride *= shape[i]; b_stride *= shape[ri]; if (strides[i] > 0) { data_size *= shape[i]; } } if (data_size == 1) { // Broadcasted scalar array is contiguous. flags.contiguous = true; } else if (data_size == in_data_size) { // Means we sliced a broadcasted dimension so leave the "no holes" flag // alone. } else { // We sliced something. So either we are row or col contiguous or we // punched a hole. 
flags.contiguous &= flags.row_contiguous || flags.col_contiguous; } return std::pair{flags, data_size}; }; std::vector indices(1, 0); indices.insert(indices.end(), indices_.begin(), indices_.end()); for (int i = 0; i < indices.size(); i++) { size_t offset = indices[i] * in.strides()[axis_]; auto [new_flags, data_size] = compute_new_flags( outputs[i].shape(), in.strides(), in.data_size(), in.flags()); outputs[i].copy_shared_buffer( in, in.strides(), new_flags, data_size, offset); } } void Squeeze::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; Strides strides; for (int i = 0, j = 0; i < in.ndim(); ++i) { if (j < axes_.size() && i == axes_[j]) { j++; } else { strides.push_back(in.strides(i)); } } out.copy_shared_buffer(in, strides, in.flags(), in.data_size()); } void StopGradient::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); out.copy_shared_buffer(inputs[0]); } void Transpose::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); Strides out_strides(out.ndim()); auto& in = inputs[0]; for (int ax = 0; ax < axes_.size(); ++ax) { out_strides[ax] = in.strides()[axes_[ax]]; } // Conditions for {row/col}_contiguous // - array must be contiguous (no gaps) // - underlying buffer size should have the same size as the array // - cumulative product of shapes is equal to the strides (we can ignore axes // with size == 1) // - in the forward direction (column contiguous) // - in the reverse direction (row contiguous) // - vectors are both row and col contiguous (hence if both row/col are // true, they stay true) auto flags = in.flags(); if (flags.contiguous && in.data_size() == in.size()) { int64_t f_stride = 1; int64_t b_stride = 1; flags.col_contiguous = true; flags.row_contiguous = true; for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) { flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1); f_stride *= out.shape(i); flags.row_contiguous &= 
(out_strides[ri] == b_stride || out.shape(ri) == 1); b_stride *= out.shape(ri); } } out.copy_shared_buffer(in, out_strides, flags, in.data_size()); } } // namespace mlx::core ================================================ FILE: mlx/backend/common/compiled.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" #include "mlx/utils.h" namespace mlx::core { void print_constant(std::ostream& os, const array& x) { switch (x.dtype()) { case float32: return print_float_constant(os, x); case float16: return print_float_constant(os, x); case bfloat16: return print_float_constant(os, x); case float64: return print_float_constant(os, x); case complex64: return print_complex_constant(os, x); case int8: os << static_cast(x.item()); return; case int16: return print_int_constant(os, x); case int32: return print_int_constant(os, x); case int64: return print_int_constant(os, x); case uint8: os << static_cast(x.item()); return; case uint16: return print_int_constant(os, x); case uint32: return print_int_constant(os, x); case uint64: return print_int_constant(os, x); case bool_: os << std::boolalpha << x.item(); return; default: throw std::runtime_error("Unsupported constant type"); } } std::string get_type_string(Dtype d) { switch (d) { case float32: return "float"; case float16: return "float16_t"; case bfloat16: return "bfloat16_t"; case float64: return "double"; case complex64: return "complex64_t"; case bool_: return "bool"; case int8: return "int8_t"; case int16: return "int16_t"; case int32: return "int32_t"; case int64: return "int64_t"; case uint8: return "uint8_t"; case uint16: return "uint16_t"; case uint32: return "uint32_t"; case uint64: return "uint64_t"; default: { std::ostringstream msg; msg << "Unsupported compilation type " << d; throw std::runtime_error(msg.str()); } } } bool compiled_check_contiguity( const std::vector& inputs, const Shape& shape) { 
bool contiguous = true; bool all_contig = true; bool all_row_contig = true; bool all_col_contig = true; int non_scalar_inputs = 0; for (const auto& x : inputs) { if (is_scalar(x)) { continue; } non_scalar_inputs++; bool shape_eq = x.shape() == shape; all_contig &= (x.flags().contiguous && shape_eq); all_row_contig &= (x.flags().row_contiguous && shape_eq); all_col_contig &= (x.flags().col_contiguous && shape_eq); } if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) { contiguous = false; } else if (non_scalar_inputs == 1 && !all_contig) { contiguous = false; } else if (non_scalar_inputs == 0 && !shape.empty()) { contiguous = false; } return contiguous; } void compiled_allocate_outputs( const std::vector& inputs, std::vector& outputs, const std::function& is_constant, bool contiguous, const std::function& mallocfn /* = allocator::malloc */) { if (contiguous) { int o = 0; Strides strides; size_t data_size; array::Flags flags; for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { auto& in = inputs[i]; // Conditions for donation // - Correct size // - Not a scalar // - Donatable // - Not a constant if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && in.is_donatable() && !is_constant(i)) { outputs[o++].copy_shared_buffer(in); } // Get representative input flags to properly set non-donated outputs if (strides.empty() && in.size() == outputs[0].size()) { strides = in.strides(); flags = in.flags(); data_size = in.data_size(); } } for (; o < outputs.size(); ++o) { outputs[o].set_data( mallocfn(data_size * outputs[o].itemsize()), data_size, strides, flags); } } else { int o = 0; for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { auto& in = inputs[i]; // Conditions for donation // - Row contiguous // - Donatable // - Correct size // - Not a constant if (in.flags().row_contiguous && in.size() == outputs[o].size() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() && !is_constant(i)) { outputs[o].copy_shared_buffer( 
in, outputs[o].strides(), in.flags(), in.data_size()); o++; } } for (; o < outputs.size(); ++o) { outputs[o].set_data(mallocfn(outputs[o].nbytes())); } } } std::tuple> compiled_collapse_contiguous_dims( const std::vector& inputs, const array& out, const std::function& is_constant) { const Shape& shape = out.shape(); bool contiguous = compiled_check_contiguity(inputs, shape); if (contiguous) { return {true, shape, {}}; } std::vector strides_vec{out.strides()}; for (size_t i = 0; i < inputs.size(); ++i) { // Skip constants. if (is_constant(i)) { continue; } // Skip scalar inputs. const auto& x = inputs[i]; if (is_scalar(x)) { continue; } // Broadcast the inputs to the output shape. Strides xstrides; size_t j = 0; for (; j < shape.size() - x.ndim(); ++j) { if (shape[j] == 1) { xstrides.push_back(out.strides()[j]); } else { xstrides.push_back(0); } } for (size_t i = 0; i < x.ndim(); ++i, ++j) { if (x.shape(i) == 1) { if (shape[j] == 1) { xstrides.push_back(out.strides()[j]); } else { xstrides.push_back(0); } } else { xstrides.push_back(x.strides()[i]); } } strides_vec.push_back(std::move(xstrides)); } auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX); return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))}; } bool compiled_use_large_index( const std::vector& inputs, const std::vector& outputs, bool contiguous) { if (contiguous) { size_t max_size = 0; for (const auto& in : inputs) { max_size = std::max(max_size, in.data_size()); } return max_size > UINT32_MAX; } else { size_t max_size = 0; for (const auto& o : outputs) { max_size = std::max(max_size, o.size()); } return max_size > UINT32_MAX; } } } // namespace mlx::core ================================================ FILE: mlx/backend/common/compiled.h ================================================ // Copyright © 2023-2024 Apple Inc. 
// NOTE(review): angle-bracketed text (template parameter/argument lists and
// `#include <...>` targets) is missing in this extract; restore it from
// upstream before compiling.
#pragma once

#include
#include
#include "mlx/array.h"
#include "mlx/primitives.h"

namespace mlx::core {

// True when `p` is a Broadcast or AsType primitive (exact typeid match).
inline bool is_static_cast(const Primitive& p) {
  return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
}

// Map a Dtype to its C/C++ type-name string (defined in compiled.cpp).
std::string get_type_string(Dtype d);

// Stream the floating-point scalar in `x` with digits10 + 1 significant digits,
// then restore the stream's previous precision.
template
void print_float_constant(std::ostream& os, const array& x) {
  auto old_precision = os.precision();
  if constexpr (std::is_same_v) {
    os << std::setprecision(std::numeric_limits::digits10 + 1);
  } else {
    os << std::setprecision(std::numeric_limits::digits10 + 1);
  }
  os << x.item() << std::setprecision(old_precision);
}

// Stream the integer scalar in `x`.
template
void print_int_constant(std::ostream& os, const array& x) {
  os << x.item();
}

// Stream the complex scalar in `x` as `<type>(real, imag)`, then restore the
// stream's previous precision.
template
void print_complex_constant(std::ostream& os, const array& x) {
  auto old_precision = os.precision();
  T constant = x.item();

  os << get_type_string(x.dtype()) << "("
     << std::setprecision(std::numeric_limits::digits10 + 1)
     << constant.real() << ", " << constant.imag() << ")"
     << std::setprecision(old_precision);
}

// Stream the scalar value of `x` as a literal (defined in compiled.cpp).
void print_constant(std::ostream& os, const array& x);

// A 0-d array counts as a scalar input.
inline bool is_scalar(const array& x) {
  return x.ndim() == 0;
}

// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
    const std::vector& inputs,
    const Shape& shape);

// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
    const std::vector& inputs,
    std::vector& outputs,
    const std::function& is_constant,
    bool contiguous,
    const std::function& mallocfn = allocator::malloc);

// Collapse contiguous dims ignoring scalars and constants.
std::tuple> compiled_collapse_contiguous_dims(
    const std::vector& inputs,
    const array& out,
    const std::function& is_constant);

// Return whether the kernel should use large index.
bool compiled_use_large_index(
    const std::vector& inputs,
    const std::vector& outputs,
    bool contiguous);

} // namespace mlx::core
================================================
FILE: mlx/backend/common/copy.h
================================================
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/backend/common/utils.h"

namespace mlx::core {

enum class CopyType {
  // Copy a raw scalar input into the full contiguous output
  Scalar,

  // Copy the raw input buffer contiguously into a raw output buffer of the same
  // size
  Vector,

  // Copy the full virtual input to the full contiguous output
  General,

  // Copy the full virtual input to the full virtual output. We assume the
  // input and output have the same shape.
  GeneralGeneral
};

// Give `out` storage for a copy of kind `ctype`. Returns true when `out` now
// shares the input's buffer (donation), false when a new buffer was allocated
// with `mallocfn`.
inline bool set_copy_output_data(
    const array& in,
    array& out,
    CopyType ctype,
    std::function mallocfn = allocator::malloc) {
  if (ctype == CopyType::Vector) {
    // If the input is donatable, we are doing a vector copy and the types
    // have the same size, then the input buffer can hold the output.
    if (is_donatable(in, out)) {
      out.copy_shared_buffer(in);
      return true;
    } else {
      // Keep the input's layout: same data size, strides and flags.
      out.set_data(
          mallocfn(in.data_size() * out.itemsize()),
          in.data_size(),
          in.strides(),
          in.flags());
      return false;
    }
  } else {
    // Any other copy type writes a full row-contiguous output.
    out.set_data(mallocfn(out.nbytes()));
    return false;
  }
}

} // namespace mlx::core
================================================
FILE: mlx/backend/common/hadamard.h
================================================
// Copyright © 2024 Apple Inc.
#pragma once #include #include "mlx/utils.h" namespace mlx::core { // From http://neilsloane.com/hadamard/ constexpr std::string_view h12 = R"( +-++++++++++ --+-+-+-+-+- +++-++----++ +---+--+-++- +++++-++---- +-+---+--+-+ ++--+++-++-- +--++---+--+ ++----+++-++ +--+-++---+- ++++----+++- +-+--+-++--- )"; constexpr std::string_view h20 = R"( +----+----++--++-++- -+----+---+++---+-++ --+----+---+++-+-+-+ ---+----+---+++++-+- ----+----++--++-++-+ -+++++-----+--+++--+ +-+++-+---+-+--+++-- ++-++--+---+-+--+++- +++-+---+---+-+--+++ ++++-----++--+-+--++ --++-+-++-+-----++++ ---++-+-++-+---+-+++ +---++-+-+--+--++-++ ++---++-+----+-+++-+ -++---++-+----+++++- -+--+--++-+----+---- +-+-----++-+----+--- -+-+-+---+--+----+-- --+-+++------+----+- +--+--++------+----+ )"; constexpr std::string_view h28 = R"( +------++----++-+--+-+--++-- -+-----+++-----+-+--+-+--++- --+-----+++---+-+-+----+--++ ---+-----+++---+-+-+-+--+--+ ----+-----+++---+-+-+++--+-- -----+-----++++--+-+--++--+- ------++----++-+--+-+--++--+ --++++-+-------++--+++-+--+- ---++++-+-----+-++--+-+-+--+ +---+++--+----++-++--+-+-+-- ++---++---+----++-++--+-+-+- +++---+----+----++-++--+-+-+ ++++--------+-+--++-++--+-+- -++++--------+++--++--+--+-+ -+-++-++--++--+--------++++- +-+-++--+--++--+--------++++ -+-+-++--+--++--+----+---+++ +-+-+-++--+--+---+---++---++ ++-+-+-++--+------+--+++---+ -++-+-+-++--+------+-++++--- +-++-+---++--+------+-++++-- -++--++-+-++-+++----++------ +-++--++-+-++-+++-----+----- ++-++---+-+-++-+++-----+---- -++-++-+-+-+-+--+++-----+--- --++-++++-+-+----+++-----+-- +--++-+-++-+-+----+++-----+- ++--++-+-++-+-+----++------+ )"; inline const std::map hadamard_matrices() { return {{12, h12}, {20, h20}, {28, h28}}; } inline std::pair decompose_hadamard(int n) { // n = m*2^k int m = 1; if (!is_power_of_2(n)) { auto h_matrices = hadamard_matrices(); for (auto [factor, _] : h_matrices) { if (n % factor == 0) { m = factor; n /= factor; break; } } if (m == 1) { throw std::invalid_argument( "[hadamard] Only 
supports n = m*2^k where m in (1, 12, 20, 28)."); } } if (n > (1 << 26)) { throw std::invalid_argument( "[hadamard] Only supports n = m*2^k where k <= 26"); } return {n, m}; } } // namespace mlx::core ================================================ FILE: mlx/backend/common/load.cpp ================================================ // Copyright © 2023 Apple Inc. #include #include #include "mlx/primitives.h" #include "mlx/scheduler.h" namespace { template void swap_endianness(uint8_t* data_bytes, size_t N) { struct Elem { uint8_t bytes[scalar_size]; }; Elem* data = reinterpret_cast(data_bytes); for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < (scalar_size / 2); j++) { std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]); } } } } // namespace namespace mlx::core { void Load::eval_cpu(const std::vector& inputs, array& out) { out.set_data(allocator::malloc(out.nbytes())); auto read_task = [out_ptr = out.data(), size = out.size(), itemsize = out.itemsize(), offset = offset_, reader = reader_, swap_endianness_ = swap_endianness_]() mutable { reader->read(out_ptr, size * itemsize, offset); if (swap_endianness_) { switch (itemsize) { case 2: swap_endianness<2>(reinterpret_cast(out_ptr), size); break; case 4: swap_endianness<4>(reinterpret_cast(out_ptr), size); break; case 8: swap_endianness<8>(reinterpret_cast(out_ptr), size); break; } } }; auto fut = io::thread_pool().enqueue(std::move(read_task)).share(); scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); }); } } // namespace mlx::core ================================================ FILE: mlx/backend/common/matmul.h ================================================ // Copyright © 2025 Apple Inc. 
// NOTE(review): this extracted text has lost its newlines and every `<...>`
// span (template arguments, #include targets). It is kept verbatim except
// for the `#pragma once` added to quantized.h and the renumbered "3." in
// the get_reduction_plan strategy comment.
//
// matmul.h
//   collapse_batches(a, b[, c]): takes a's shape and every operand's strides
//   without their last two dims, merges contiguous batch dims across all
//   operands with collapse_contiguous_dims, and returns (batch_shape, batch
//   strides per operand). A 2-D `a`, or a batch that collapses to nothing,
//   yields shape {1} with stride 0 for every operand.
// quantized.h
//   get_pack_factor / get_bytes_per_pack: for bits in {3, 5, 6}, and for
//   bits in {1, 2, 4, 8} with the default wsize = 8, they satisfy
//   pack_factor * bits == 8 * bytes_per_pack. This header had no include
//   guard, so reaching it twice in one translation unit would redefine these
//   inline functions. Every other header here uses `#pragma once`.
// reduce.cpp
//   shapes_without_reduction_axes: erases the reduced axes from shape and
//   strides, back to front (presumably `axes` is sorted ascending; confirm
//   with callers).
//   get_reduction_plan: classifies the reduction as one ReductionOpType plus
//   the merged reduction shape/strides.
//   NOTE(review): both non-trivial paths read strides.back(). If every
//   reduced axis has size 1, that vector is empty. Presumably callers never
//   pass that case (TODO confirm).
// reduce.h
//   ReductionPlan's one-argument constructor must stay implicit:
//   get_reduction_plan relies on it in `return ContiguousAllReduce;`.
// slicing.cpp
//   prepare_slice: returns (offset of the first element, input strides
//   scaled by the slice strides). slice() sizes the shared view from the
//   extents reached by positive and negative strides, throws
//   std::runtime_error if that size is negative, and aliases the input
//   buffer via shared_buffer_slice.
#pragma once #include "mlx/backend/common/utils.h" #include "mlx/utils.h" #include namespace mlx::core { inline std::tuple collapse_batches( const array& a, const array& b) { if (a.ndim() == 2) { return {Shape{1}, Strides{0}, Strides{0}}; } Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; auto [batch_shape, batch_strides] = collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride}); auto a_batch_strides = batch_strides[0]; auto b_batch_strides = batch_strides[1]; if (batch_shape.empty()) { batch_shape.push_back(1); a_batch_strides.push_back(0); b_batch_strides.push_back(0); } return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides); } inline std::tuple collapse_batches(const array& a, const array& b, const array& c) { if (a.ndim() == 2) { return {Shape{1}, Strides{0}, Strides{0}, Strides{0}}; } Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; auto [batch_shape, batch_strides] = collapse_contiguous_dims( A_bshape, std::vector{A_bstride, B_bstride, C_bstride}); auto A_batch_stride = batch_strides[0]; auto B_batch_stride = batch_strides[1]; auto C_batch_stride = batch_strides[2]; if (batch_shape.empty()) { batch_shape.push_back(1); A_batch_stride.push_back(0); B_batch_stride.push_back(0); C_batch_stride.push_back(0); } return std::make_tuple( batch_shape, A_batch_stride, B_batch_stride, C_batch_stride); } } // namespace mlx::core ================================================ FILE: mlx/backend/common/quantized.h ================================================ // Copyright © 2026 Apple Inc. #pragma once namespace mlx::core { inline constexpr short get_pack_factor(int bits, int wsize = 8) { return (bits == 3 || bits == 5) ?
8 : (bits == 6 ? 4 : wsize / bits); } inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) { bool power_of_2_bits = (bits & (bits - 1)) == 0; return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); } } // namespace mlx::core ================================================ FILE: mlx/backend/common/reduce.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/backend/common/reduce.h" namespace mlx::core { std::pair shapes_without_reduction_axes( Shape shape, Strides strides, const std::vector& axes) { for (int i = axes.size() - 1; i >= 0; i--) { int a = axes[i]; shape.erase(shape.begin() + a); strides.erase(strides.begin() + a); } return std::make_pair(shape, strides); } std::pair shapes_without_reduction_axes( const array& x, const std::vector& axes) { auto shape = x.shape(); auto strides = x.strides(); return shapes_without_reduction_axes( std::move(shape), std::move(strides), axes); } ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { // The data is all there and we are reducing over everything if (x.size() == x.data_size() && axes.size() == x.ndim() && x.flags().contiguous) { return ContiguousAllReduce; } // Row contiguous input so the output is row contiguous if (x.flags().row_contiguous) { // Merge consecutive axes Shape shape = {x.shape(axes[0])}; Strides strides = {x.strides()[axes[0]]}; for (int i = 1; i < axes.size(); i++) { if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) { shape.back() *= x.shape(axes[i]); strides.back() = x.strides()[axes[i]]; } else { shape.push_back(x.shape(axes[i])); strides.push_back(x.strides()[axes[i]]); } } // Remove singleton axes from the plan for (int i = shape.size() - 1; i >= 0; i--) { if (shape[i] == 1) { shape.erase(shape.begin() + i); strides.erase(strides.begin() + i); } } if (strides.back() == 1) { return ReductionPlan(ContiguousReduce, shape, strides); } else if (strides.back() > 1) { return
ReductionPlan(ContiguousStridedReduce, shape, strides); } } // Let's check if we can optimize our access patterns // // 1. We have a reduction axis with stride 1. Simply call // GeneralContiguousReduce and be done with it. // 2. We have transpositions and we are not reducing over the axis with // stride 1. However, we are reducing over an axis where everything is // contiguous in memory to the right of that axis. We can call strided // reduce and be done with it. // 3. We have weird transpositions and expands. Copy the strides to the // output, then call strided reduce. // Sort reduction axes by stride in order to merge them and figure out if we // have a contiguous reduction. std::vector> reductions; for (auto a : axes) { if (x.shape(a) > 1) { reductions.push_back(std::make_pair(x.shape(a), x.strides()[a])); } } std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) { bool a_is_zero = a.second == 0; bool b_is_zero = b.second == 0; return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second; }); // Extract the two smallest and try to merge them in case the contiguous // reduction can be bigger than just the last axis. for (int i = reductions.size() - 1; i >= 1; i--) { auto a = reductions[i]; auto b = reductions[i - 1]; // b.stride = a.shape * a.stride then a and b are contiguous if (b.second == a.first * a.second) { reductions.erase(reductions.begin() + i); reductions[i - 1] = std::make_pair(a.first * b.first, a.second); } } Shape shape; Strides strides; for (auto r : reductions) { shape.push_back(r.first); strides.push_back(r.second); } // We can call the contiguous reduction op for every weird way the input is // structured in the rest of the axes. if (strides.back() == 1) { return ReductionPlan(GeneralContiguousReduce, shape, strides); } // Delegate to the general strided reduction op if the axes after // strides.back() are contiguous.
if (strides.back() > 1) { int64_t size = 1; bool have_expand = false; for (int i = x.ndim() - 1; i >= 0; i--) { if (axes.back() == i) { continue; } auto stride_i = x.strides()[i]; auto shape_i = x.shape(i); if (stride_i == 0) { if (shape_i == 1) { continue; } have_expand = true; break; } if (stride_i != size && shape_i != 1) { break; } size *= shape_i; } // In the case of an expanded dimension we are being conservative and // require the smallest reduction stride to be smaller than the maximum row // contiguous size. The reason is that we can't easily know if the reduced // axis is before or after an expanded dimension. if (size > strides.back() || (size == strides.back() && !have_expand)) { return ReductionPlan(GeneralStridedReduce, shape, strides); } } return ReductionPlan(GeneralReduce, shape, strides); } } // namespace mlx::core ================================================ FILE: mlx/backend/common/reduce.h ================================================ // Copyright © 2023 Apple Inc. #pragma once #include "mlx/backend/common/utils.h" namespace mlx::core { enum ReductionOpType { // Self-explanatory. Read everything and produce 1 output. ContiguousAllReduce, // The input is contiguous and the last axis is reduced // N1xR1xN2xR2x...xNnxRn ContiguousReduce, // The input is contiguous and the last axis is not reduced // R1xN1xR2xN2x...xRnxNn ContiguousStridedReduce, // The input is not contiguous but the last axis is and it is reduced so we // need to figure out the offsets but we can call the contiguous reduce after // that. // N3xR1xN1xR4x...xRn GeneralContiguousReduce, // The input is not contiguous but the last reduction axis and the last axis // are so we need to figure out the offset but we can call the strided reduce // after that. GeneralStridedReduce, // The input is not contiguous after the reduction axis and it may contain // 0-stride axes or transpositions.
We could copy the strides and produce a // transposed outcome or we can read the input out of order and write the // output in order. GeneralReduce }; struct ReductionPlan { ReductionOpType type; Shape shape; Strides strides; ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_) : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {} ReductionPlan(ReductionOpType type_) : type(type_) {} }; ReductionPlan get_reduction_plan(const array& x, const std::vector& axes); std::pair shapes_without_reduction_axes( const array& x, const std::vector& axes); std::pair shapes_without_reduction_axes( Shape shape, Strides strides, const std::vector& axes); } // namespace mlx::core ================================================ FILE: mlx/backend/common/slicing.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/backend/common/utils.h" namespace mlx::core { std::tuple prepare_slice( const array& in, const Shape& start_indices, const Shape& strides) { int64_t data_offset = 0; Strides inp_strides(in.ndim(), 0); for (int i = 0; i < in.ndim(); ++i) { data_offset += start_indices[i] * in.strides()[i]; inp_strides[i] = in.strides()[i] * strides[i]; } return std::make_tuple(data_offset, inp_strides); } void shared_buffer_slice( const array& in, const Strides& out_strides, int64_t data_offset, size_t data_size, array& out) { // Compute row/col contiguity auto [no_bsx_size, is_row_contiguous, is_col_contiguous] = check_contiguity(out.shape(), out_strides); auto flags = in.flags(); flags.row_contiguous = is_row_contiguous; flags.col_contiguous = is_col_contiguous; flags.contiguous = (no_bsx_size == data_size); out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset); } void slice( const array& in, array& out, const Shape& start_indices, const Shape& strides) { if (out.size() == 0) { out.set_data(allocator::malloc(0)); return; } // Calculate out strides, initial offset auto [data_offset, inp_strides] =
prepare_slice(in, start_indices, strides); // Get the location of the end based on the inp strides and out.shape() int64_t low_idx = 0; int64_t high_idx = 0; for (int i = 0; i < inp_strides.size(); ++i) { auto delta = inp_strides[i] * (out.shape()[i] - 1); if (inp_strides[i] > 0) { high_idx += delta; } else { low_idx += delta; } } int64_t data_size = (high_idx - low_idx) + 1; if (data_size < 0) { std::ostringstream msg; msg << "[slice] Computed invalid data size: " << data_size << "."; throw std::runtime_error(msg.str()); } shared_buffer_slice(in, inp_strides, data_offset, data_size, out); } } // namespace mlx::core ================================================ FILE: mlx/backend/common/slicing.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include "mlx/array.h" namespace mlx::core { std::tuple prepare_slice( const array& in, const Shape& start_indices, const Shape& strides); void slice( const array& in, array& out, const Shape& start_indices, const Shape& strides); } // namespace mlx::core ================================================ FILE: mlx/backend/common/ternary.h ================================================ // Copyright © 2023 Apple Inc. #pragma once #include "mlx/allocator.h" #include "mlx/array.h" #include "mlx/backend/common/utils.h" namespace mlx::core { // TODO: Add support for more combinations of input types.
// get_ternary_op_type(a, b, c) picks the kernel layout from the inputs:
//   ScalarScalarScalar - all three inputs have data_size() == 1;
//   VectorVectorVector - all three are row-contiguous, or all three are
//                        col-contiguous;
//   VectorScalarVector - b has data_size() == 1, a and c are row-contiguous;
//   VectorVectorScalar - c has data_size() == 1, a and b are row-contiguous;
//   General            - anything else.
// set_ternary_op_output_data(...) gives `out` storage for that layout:
//   - ScalarScalarScalar: a one-element buffer carrying b's strides/flags.
//   - Every other layout first tries to donate an input buffer
//     (is_donatable). For the non-VVV layouts, only row-contiguous inputs
//     are considered.
//   - Otherwise it falls back to mallocfn. The VectorVectorVector fallback
//     copies b's data_size/strides/flags; the others allocate out.nbytes().
// NOTE(review): in this extracted text `std::function mallocfn` has lost its
// template argument, and the newlines between statements are collapsed.
enum class TernaryOpType { ScalarScalarScalar, VectorVectorVector, VectorVectorScalar, VectorScalarVector, General, }; inline TernaryOpType get_ternary_op_type(const array& a, const array& b, const array& c) { TernaryOpType topt; if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) { topt = TernaryOpType::ScalarScalarScalar; } else if ( (a.flags().row_contiguous && b.flags().row_contiguous && c.flags().row_contiguous) || (a.flags().col_contiguous && b.flags().col_contiguous && c.flags().col_contiguous)) { topt = TernaryOpType::VectorVectorVector; } else if ( b.data_size() == 1 && a.flags().row_contiguous && c.flags().row_contiguous) { topt = TernaryOpType::VectorScalarVector; } else if ( c.data_size() == 1 && a.flags().row_contiguous && b.flags().row_contiguous) { topt = TernaryOpType::VectorVectorScalar; } else { topt = TernaryOpType::General; } return topt; } inline void set_ternary_op_output_data( const array& a, const array& b, const array& c, array& out, TernaryOpType topt, std::function mallocfn = allocator::malloc) { auto maybe_donate = [&out](const array& x) { if (is_donatable(x, out)) { out.copy_shared_buffer(x); return true; } return false; }; switch (topt) { case TernaryOpType::ScalarScalarScalar: out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags()); break; case TernaryOpType::VectorVectorVector: if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { out.set_data( mallocfn(out.itemsize() * b.data_size()), b.data_size(), b.strides(), b.flags()); } break; case TernaryOpType::VectorVectorScalar: case TernaryOpType::VectorScalarVector: case TernaryOpType::General: // Try to donate an input which is row_contiguous if (!((a.flags().row_contiguous && maybe_donate(a)) || (b.flags().row_contiguous && maybe_donate(b)) || (c.flags().row_contiguous && maybe_donate(c)))) { out.set_data(mallocfn(out.nbytes())); } break; } } } // namespace mlx::core ================================================ FILE: mlx/backend/common/unary.h
================================================ // Copyright © 2025 Apple Inc.
// NOTE(review): this extracted text has lost its newlines and every `<...>`
// span. It is verbatim except for two repairs:
//   - the `&current_binary_dir` operand, which extraction had mangled into
//     a currency-sign character, is restored;
//   - get_block_dims_common now honours pow2.
// unary.h
//   set_unary_output_data(in, out, mallocfn):
//   - contiguous `in`: donates in's buffer to `out` when is_donatable
//     allows it; otherwise allocates in.data_size() elements and copies in's
//     strides/flags.
//   - non-contiguous `in`: gets a fresh out.nbytes() buffer.
// utils.cpp
//   current_binary_dir(): caches the parent directory of the shared object
//   containing this function, found via dladdr on its own address.
//   Also defines the collapse_contiguous_dims overloads and the block/grid
//   sizing helpers.
#pragma once #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" namespace mlx::core { inline void set_unary_output_data( const array& in, array& out, std::function mallocfn = allocator::malloc) { if (in.flags().contiguous) { if (is_donatable(in, out)) { out.copy_shared_buffer(in); } else { out.set_data( mallocfn(in.data_size() * out.itemsize()), in.data_size(), in.strides(), in.flags()); } } else { out.set_data(mallocfn(out.nbytes())); } } } // namespace mlx::core ================================================ FILE: mlx/backend/common/utils.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include #include "mlx/backend/common/utils.h" namespace mlx::core { std::filesystem::path current_binary_dir() { static std::filesystem::path binary_dir = []() { Dl_info info; if (!dladdr(reinterpret_cast(&current_binary_dir), &info)) { throw std::runtime_error("Unable to get current binary dir."); } return std::filesystem::path(info.dli_fname).parent_path(); }(); return binary_dir; } std::tuple> collapse_contiguous_dims( const Shape& shape, const std::vector& strides, int64_t size_cap) { // Make a vector that has axes separated with -1. Collapse all axes between // -1.
Shape to_collapse; if (shape.size() > 0) { if (shape[0] != 1) { to_collapse.push_back(0); } size_t size = shape[0]; for (int i = 1; i < shape.size(); i++) { bool contiguous = true; size *= shape[i]; for (const auto& st : strides) { if (st[i] * shape[i] != st[i - 1] || size > size_cap) { contiguous = false; size = shape[i]; break; } } if (!contiguous) { to_collapse.push_back(-1); } if (shape[i] != 1) { to_collapse.push_back(i); } } to_collapse.push_back(-1); } Shape out_shape; std::vector out_strides(strides.size()); for (int i = 0;;) { while (i < to_collapse.size() && to_collapse[i] == -1) { ++i; }; if (i == to_collapse.size()) { break; } int current_shape = shape[to_collapse[i]]; int k = i; while (to_collapse[++k] != -1) { current_shape *= shape[to_collapse[k]]; } out_shape.push_back(current_shape); for (int j = 0; j < strides.size(); j++) { const auto& st = strides[j]; out_strides[j].push_back(st[to_collapse[k - 1]]); } i = k + 1; } if (!shape.empty() && out_shape.empty()) { out_shape.push_back(1); for (auto& out_stride : out_strides) { out_stride.push_back(0); } } return std::make_tuple(out_shape, out_strides); } std::pair collapse_contiguous_dims( const Shape& shape, const Strides& strides, int64_t size_cap) { Shape collapsed_shape; Strides collapsed_strides; if (shape.size() > 0) { collapsed_shape.push_back(shape[0]); collapsed_strides.push_back(strides[0]); for (int i = 1; i < shape.size(); i++) { if (shape[i] == 1) { continue; } else if ( strides[i] * shape[i] != collapsed_strides.back() || collapsed_shape.back() * static_cast(shape[i]) > size_cap) { collapsed_shape.push_back(shape[i]); collapsed_strides.push_back(strides[i]); } else { collapsed_shape.back() *= shape[i]; collapsed_strides.back() = strides[i]; } } } return std::make_pair(collapsed_shape, collapsed_strides); } std::pair collapse_contiguous_dims( const array& a, int64_t size_cap /* = std::numeric_limits::max()*/) { return collapse_contiguous_dims(a.shape(), a.strides(), size_cap); }
// get_block_dims_common(dim0, dim1, dim2, pow2): doubles the block extent of
// dim0, dim1 and dim2 in turn, each only while 2^(p + 1) <= dim, and returns
// (2^p0, 2^p1, 2^p2). It stops when no extent can grow or when
// p0 + p1 + p2 reaches pow2, so for pow2 >= 1 the product of the returned
// extents is at most 2^pow2. The two mid-round checks used to compare
// against a literal 10, which ignored any caller-supplied pow2 other than
// the default.
Dims
get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) { int pows[3] = {0, 0, 0}; int sum = 0; while (true) { int presum = sum; // Check all the pows if (dim0 >= (1 << (pows[0] + 1))) { pows[0]++; sum++; } if (sum >= pow2) { break; } if (dim1 >= (1 << (pows[1] + 1))) { pows[1]++; sum++; } if (sum >= pow2) { break; } if (dim2 >= (1 << (pows[2] + 1))) { pows[2]++; sum++; } if (sum == presum || sum >= pow2) { break; } } return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]); } Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) { // Dims with strides of 0 are ignored as they // correspond to broadcasted dimensions size_t grid_x = 1; size_t grid_y = 1; for (int i = 0; i < shape.size(); ++i) { if (strides[i] == 0) { continue; } if (grid_x * shape[i] < UINT32_MAX) { grid_x *= shape[i]; } else { grid_y *= shape[i]; } } if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { throw std::runtime_error("Unable to safely factor shape."); } if (grid_y > grid_x) { std::swap(grid_x, grid_y); } return std::make_tuple( static_cast(grid_x), static_cast(grid_y), 1); } Dims get_2d_grid_dims_common( const Shape& shape, const Strides& strides, size_t divisor) { // Compute the 2d grid dimensions such that the total size of the grid is // divided by divisor. size_t grid_x = 1; size_t grid_y = 1; for (int i = 0; i < shape.size(); ++i) { if (strides[i] == 0) { continue; } // No need to add this shape we can just remove it from the divisor.
if (divisor % shape[i] == 0) { divisor /= shape[i]; continue; } if (grid_x * shape[i] < UINT32_MAX) { grid_x *= shape[i]; } else { grid_y *= shape[i]; } if (divisor > 1) { if (grid_x % divisor == 0) { grid_x /= divisor; divisor = 1; } else if (grid_y % divisor == 0) { grid_y /= divisor; divisor = 1; } } } if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { throw std::runtime_error("Unable to safely factor shape."); } if (grid_y > grid_x) { std::swap(grid_x, grid_y); } if (divisor > 1) { grid_x = ((grid_x + divisor - 1) / divisor) * divisor; } return std::make_tuple( static_cast(grid_x), static_cast(grid_y), 1); } std::pair get_grid_and_block_common(int dim0, int dim1, int dim2) { auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2); auto gx = (dim0 + bx - 1) / bx; auto gy = (dim1 + by - 1) / by; auto gz = (dim2 + bz - 1) / bz; return std::make_pair( std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz)); } } // namespace mlx::core ================================================ FILE: mlx/backend/common/utils.h ================================================ // Copyright © 2023-2024 Apple Inc. #pragma once #include #include #include #include "mlx/array.h" namespace mlx::core { // Return the directory that contains current shared library. std::filesystem::path current_binary_dir(); inline int64_t elem_to_loc(int elem, const Shape& shape, const Strides& strides) { int64_t loc = 0; for (int i = shape.size() - 1; i >= 0; --i) { auto q_and_r = ldiv(elem, shape[i]); loc += q_and_r.rem * strides[i]; elem = q_and_r.quot; } return loc; } inline int64_t elem_to_loc(int elem, const array& a) { if (a.flags().row_contiguous) { return elem; } return elem_to_loc(elem, a.shape(), a.strides()); } inline Strides make_contiguous_strides(const Shape& shape) { Strides strides(shape.size(), 1); for (int i = shape.size() - 1; i > 0; i--) { strides[i - 1] = strides[i] * shape[i]; } return strides; } // Collapse dims that are contiguous to possibly route to a better kernel // e.g.
for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1}) // should return {{2, 4}, {{1, 2}}}. // // When multiple arrays are passed they should all have the same shape. The // collapsed axes are also the same so one shape is returned. std::tuple> collapse_contiguous_dims( const Shape& shape, const std::vector& strides, int64_t size_cap = std::numeric_limits::max()); inline std::tuple> collapse_contiguous_dims( const std::vector& xs, size_t size_cap = std::numeric_limits::max()) { std::vector strides; for (auto& x : xs) { strides.emplace_back(x.strides()); } return collapse_contiguous_dims(xs[0].shape(), strides, size_cap); } template > inline auto collapse_contiguous_dims(Arrays&&... xs) { return collapse_contiguous_dims( std::vector{std::forward(xs)...}); } // The single array version of the above. std::pair collapse_contiguous_dims( const Shape& shape, const Strides& strides, int64_t size_cap = std::numeric_limits::max()); std::pair collapse_contiguous_dims( const array& a, int64_t size_cap = std::numeric_limits::max()); // Compute the thread block dimensions which fit the given // input dimensions. // - The thread block dimensions will be powers of two // - The thread block size will be at most 2^pow2 using Dims = std::tuple; Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10); // Computes a 2D grid where each element is < UINT_MAX // Assumes: // - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 // - shape and strides correspond to a contiguous (no holes) but // possibly broadcasted array Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides); // Same as above but we do an implicit division with divisor. // Basically, equivalent to factorizing // Prod(s \forall s in shape if strides[s] > 0) / divisor. Dims get_2d_grid_dims_common( const Shape& shape, const Strides& strides, size_t divisor); // Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
std::pair get_grid_and_block_common(int dim0, int dim1, int dim2);
// NOTE(review): this extracted text has lost its newlines and every `<...>`
// span. The code below is verbatim apart from the step(int64_t) fix.
// ContiguousIterator: walks an array's shape/strides (collapsed with
// collapse_contiguous_dims) in row-major order and keeps the current element
// offset in `loc`.
//   - step() advances one element; step(s) advances s elements.
//   - seek(n) jumps to flat index n; reset() rewinds.
//   - contiguous_suffix() returns the innermost extent when its stride is 1,
//     and 0 otherwise.
// When step(s) exhausts the inner axis, it carries into an outer axis and
// must then resume from the innermost axis (i = dims - 1). It used to keep
// stepping the outer axis instead: for shape {2, 2} and strides {1, 2},
// step(3) from the start ended at loc 2 instead of 3.
struct ContiguousIterator { inline void step() { int dims = shape_.size(); if (dims == 0) { return; } int i = dims - 1; while (pos_[i] == (shape_[i] - 1) && i > 0) { pos_[i] = 0; loc -= (shape_[i] - 1) * strides_[i]; i--; } pos_[i]++; loc += strides_[i]; } void step(int64_t s) { int dims = shape_.size(); if (dims == 0) { return; } int i = dims - 1; while (s > 0) { if (shape_[i] - pos_[i] > 1) { int steps = static_cast( std::min(static_cast(shape_[i] - pos_[i] - 1), s)); pos_[i] += steps; loc += strides_[i] * steps; s -= steps; } else { while (pos_[i] == (shape_[i] - 1) && i > 0) { pos_[i] = 0; loc -= (shape_[i] - 1) * strides_[i]; i--; } pos_[i]++; loc += strides_[i]; s--; i = dims - 1; } } } int64_t contiguous_suffix() { if (shape_.size() == 0) { return 0; } return (strides_.back() == 1) ? shape_.back() : 0; } void seek(int64_t n) { loc = 0; for (int i = shape_.size() - 1; i >= 0; --i) { auto q_and_r = ldiv(n, shape_[i]); loc += q_and_r.rem * strides_[i]; pos_[i] = q_and_r.rem; n = q_and_r.quot; } } void reset() { loc = 0; std::fill(pos_.begin(), pos_.end(), 0); } ContiguousIterator() {}; explicit ContiguousIterator(const array& a) : shape_(a.shape()), strides_(a.strides()) { if (!shape_.empty()) { std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); pos_ = Shape(shape_.size(), 0); } } explicit ContiguousIterator( const Shape& shape, const Strides& strides, int dims) : shape_(shape.begin(), shape.begin() + dims), strides_(strides.begin(), strides.begin() + dims) { if (!shape_.empty()) { std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); pos_ = Shape(shape_.size(), 0); } } int64_t loc{0}; private: Shape shape_; Strides strides_; Shape pos_; };
// check_contiguity(shape, strides) returns (product of the extents whose
// stride is > 0, is_row_contiguous, is_col_contiguous). Size-1 axes never
// break either contiguity flag.
inline auto check_contiguity(const Shape& shape, const Strides& strides) { size_t no_broadcast_data_size = 1; int64_t f_stride = 1; int64_t b_stride = 1; bool is_row_contiguous = true; bool is_col_contiguous = true; for (int i = 0,
ri = shape.size() - 1; ri >= 0; i++, ri--) { is_col_contiguous &= strides[i] == f_stride || shape[i] == 1; is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1; f_stride *= shape[i]; b_stride *= shape[ri]; if (strides[i] > 0) { no_broadcast_data_size *= shape[i]; } } return std::make_tuple( no_broadcast_data_size, is_row_contiguous, is_col_contiguous); }
// is_donatable(in, out): `in`'s buffer may be reused for `out` when `in`
// reports is_donatable(), the item sizes match, and in's buffer is at most
// donation_extra (16 KiB) larger than out.nbytes().
inline bool is_donatable(const array& in, const array& out) { constexpr size_t donation_extra = 16384; return in.is_donatable() && in.itemsize() == out.itemsize() && in.buffer_size() <= out.nbytes() + donation_extra; } std::pair prepare_reshape(const array& in, const array& out); void shared_buffer_reshape( const array& in, const Strides& out_strides, array& out); template inline SmallVector remove_index(SmallVector vec, size_t index) { vec.erase(std::next(vec.begin(), index)); return vec; } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/CMakeLists.txt ================================================ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") set(COMPILER ${CMAKE_C_COMPILER}) set(CLANG TRUE) else() set(COMPILER ${CMAKE_CXX_COMPILER}) endif() set(COMPILE_DEPS ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h ${PROJECT_SOURCE_DIR}/mlx/types/complex.h simd/simd.h simd/base_simd.h simd/math.h simd/type.h unary_ops.h binary_ops.h) if(MSVC) set(SHELL_EXT ps1) set(SHELL_CMD powershell -ExecutionPolicy Bypass -File) else() set(SHELL_EXT sh) set(SHELL_CMD bash) endif() add_custom_command( OUTPUT compiled_preamble.cpp COMMAND ${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT} ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER} ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR} DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h ${COMPILE_DEPS}) add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp) add_dependencies(mlx
cpu_compiled_preamble) target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cblas.cpp ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp) if(MLX_BUILD_ACCELERATE) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp) endif() if(IOS) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../no_cpu/compiled.cpp) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp) endif() ================================================ FILE: mlx/backend/cpu/arange.h ================================================ // Copyright © 2023 Apple Inc.
// NOTE(review): this extracted text has lost its newlines and every `<...>`
// span. It is verbatim except for the arange loop-index type.
// arange.h
//   arange(start, next, out, size, stream): dispatches a task on the
//   stream's CPU command encoder that writes `size` values into out, namely
//   start, start + step, start + 2*step, ... (step = next - start), built by
//   repeated addition. The loop index is size_t to match `size`; an int
//   index overflows (undefined behaviour) once size exceeds INT_MAX.
// arg_reduce.cpp
//   arg_reduce: for every output element, locates its slice along `axis`
//   with elem_to_loc and records the index of the first strict minimum
//   (ArgMin) or maximum (ArgMax) in that slice.
//   ArgReduce::eval_cpu allocates `out` and dispatches on the input dtype.
//   NOTE(review): the output loop uses a uint32_t index and elem_to_loc
//   takes an int, so outputs with more than 2^31 elements look unsupported
//   here. Confirm before relying on it.
#pragma once #include "mlx/array.h" #include "mlx/backend/cpu/encoder.h" namespace mlx::core { namespace { template void arange(T start, T next, array& out, size_t size, Stream stream) { auto ptr = out.data(); auto step_size = next - start; auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(out); encoder.dispatch([ptr, start, step_size, size]() mutable { for (size_t i = 0; i < size; ++i) { ptr[i] = start; start += step_size; } }); } } // namespace } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/arg_reduce.cpp ================================================ // Copyright © 2023 Apple Inc. #include #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" namespace mlx::core { namespace { template void arg_reduce(const array& in, array& out, const OpT& op, int axis) { auto axis_size = in.shape()[axis]; auto axis_stride = in.strides()[axis]; Strides strides = remove_index(in.strides(), axis); Shape shape = remove_index(in.shape(), axis); auto in_ptr = in.data(); auto out_ptr = out.data(); for (uint32_t i = 0; i < out.size(); ++i) { auto loc = elem_to_loc(i, shape, strides); auto local_in_ptr = in_ptr + loc; uint32_t ind_v = 0; InT v = (*local_in_ptr); for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) { op(j, (*local_in_ptr), &ind_v, &v); } out_ptr[i] = ind_v; } } template void arg_reduce_dispatch( const array& in, array& out, ArgReduce::ReduceType rtype, int axis) { switch (rtype) { case ArgReduce::ArgMin: { auto op = [](auto ind_x, auto x, auto ind_y, auto y) { if (x < (*y)) { (*y) = x; (*ind_y) = ind_x; } }; arg_reduce(in, out, op, axis); break; } case ArgReduce::ArgMax: { auto op = [](auto ind_x, auto x, auto ind_y, auto y) { if (x > (*y)) { (*y) = x; (*ind_y) = ind_x; } }; arg_reduce(in, out, op, axis); break; } } } } // namespace void ArgReduce::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in
= inputs[0]; out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); encoder.set_output_array(out); encoder.dispatch([in = array::unsafe_weak_copy(in), out = array::unsafe_weak_copy(out), reduce_type_ = reduce_type_, axis_ = axis_]() mutable { switch (in.dtype()) { case bool_: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case uint8: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case uint16: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case uint32: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case uint64: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case int8: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case int16: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case int32: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case int64: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case float16: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case float32: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case bfloat16: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case float64: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; case complex64: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; } }); } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/binary.cpp ================================================ // Copyright © 2023 Apple Inc.
#include #include #include #include "mlx/allocator.h" #include "mlx/backend/cpu/binary.h" #include "mlx/backend/cpu/binary_ops.h" #include "mlx/backend/cpu/binary_two.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { void Add::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary_op_cpu(a, b, out, detail::Add(), stream()); } void DivMod::eval_cpu( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; auto bopt = get_binary_op_type(a, b); auto& out_a = outputs[0]; auto& out_b = outputs[1]; set_binary_op_output_data(a, b, out_a, bopt); set_binary_op_output_data(a, b, out_b, bopt); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out_a); encoder.set_output_array(out_b); encoder.dispatch([a = array::unsafe_weak_copy(a), b = array::unsafe_weak_copy(b), out_a = array::unsafe_weak_copy(out_a), out_b = array::unsafe_weak_copy(out_b), bopt]() mutable { auto integral_op = [](auto x, auto y) { return std::make_pair(x / y, x % y); }; auto float_op = [](auto x, auto y) { return std::make_pair(std::trunc(x / y), std::fmod(x, y)); }; switch (out_a.dtype()) { case bool_: binary_op(a, b, out_a, out_b, integral_op, bopt); case uint8: binary_op(a, b, out_a, out_b, integral_op, bopt); break; case uint16: binary_op(a, b, out_a, out_b, integral_op, bopt); break; case uint32: binary_op(a, b, out_a, out_b, integral_op, bopt); break; case uint64: binary_op(a, b, out_a, out_b, integral_op, bopt); break; case int8: binary_op(a, b, out_a, out_b, integral_op, bopt); break; case int16: binary_op(a, b, out_a, out_b, integral_op, bopt); break; case int32: binary_op(a, b, out_a, out_b, integral_op, bopt); break; case int64: binary_op(a, b, out_a, out_b, integral_op, bopt); break; case float16: binary_op(a, b, out_a, 
out_b, float_op, bopt); break; case float32: binary_op(a, b, out_a, out_b, float_op, bopt); break; case float64: binary_op(a, b, out_a, out_b, float_op, bopt); break; case bfloat16: binary_op(a, b, out_a, out_b, float_op, bopt); break; case complex64: // Should never get here throw std::runtime_error("[DivMod] Complex type not supported"); break; } }); } void Divide::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary_op_cpu(a, b, out, detail::Divide(), stream()); } void Remainder::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary_op_cpu(a, b, out, detail::Remainder(), stream()); } void Equal::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; if (equal_nan_) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); encoder.dispatch([a = array::unsafe_weak_copy(a), b = array::unsafe_weak_copy(b), out = array::unsafe_weak_copy(out), bopt]() mutable { switch (a.dtype()) { case float16: binary_op(a, b, out, bopt); break; case float32: binary_op(a, b, out, bopt); break; case float64: binary_op(a, b, out, bopt); break; case bfloat16: binary_op(a, b, out, bopt); break; case complex64: binary_op(a, b, out, bopt); break; default: throw std::runtime_error( "[NanEqual::eval_cpu] Only for floating point types."); } }); } else { comparison_op_cpu(a, b, out, detail::Equal(), stream()); } } void Greater::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream()); } void GreaterEqual::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); comparison_op_cpu( inputs[0], inputs[1], out, 
detail::GreaterEqual(), stream()); } void Less::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream()); } void LessEqual::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream()); } void LogAddExp::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream()); } void LogicalAnd::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); // LogicalAnd requires two input arrays auto& in1 = inputs[0]; auto& in2 = inputs[1]; binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream()); } void LogicalOr::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); // LogicalOr requires two input arrays auto& in1 = inputs[0]; auto& in2 = inputs[1]; binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream()); } void Maximum::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary_op_cpu(a, b, out, detail::Maximum(), stream()); } void Minimum::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary_op_cpu(a, b, out, detail::Minimum(), stream()); } void Multiply::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary_op_cpu(a, b, out, detail::Multiply(), stream()); } void NotEqual::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream()); } void Power::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary_op_cpu(a, b, out, detail::Power(), stream()); } void 
Subtract::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary_op_cpu(a, b, out, detail::Subtract(), stream()); } void BitwiseBinary::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; switch (op_) { case BitwiseBinary::And: binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream()); break; case BitwiseBinary::Or: binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream()); break; case BitwiseBinary::Xor: binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream()); break; case BitwiseBinary::LeftShift: binary_int_op_cpu(a, b, out, detail::LeftShift(), stream()); break; case BitwiseBinary::RightShift: binary_int_op_cpu(a, b, out, detail::RightShift(), stream()); break; } } void ArcTan2::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); const auto& a = inputs[0]; const auto& b = inputs[1]; binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream()); } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/binary.h ================================================ // Copyright © 2023 Apple Inc. 
#pragma once #include #include "mlx/array.h" #include "mlx/backend/common/binary.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" namespace mlx::core { template struct VectorScalar { template void operator()(const T* a, const T* b, U* dst, int size) { T scalar = *b; constexpr int N = simd::max_size; while (size >= N) { simd::store(dst, Op{}(simd::load(a), simd::Simd(scalar))); dst += N; a += N; size -= N; } while (size-- > 0) { *dst = Op{}(*a, scalar); dst++; a++; } } }; template struct ScalarVector { template void operator()(const T* a, const T* b, U* dst, int size) { T scalar = *a; constexpr int N = simd::max_size; while (size >= N) { simd::store(dst, Op{}(simd::Simd(scalar), simd::load(b))); dst += N; b += N; size -= N; } while (size-- > 0) { *dst = Op{}(scalar, *b); dst++; b++; } } }; template struct VectorVector { template void operator()(const T* a, const T* b, U* dst, int size) { constexpr int N = simd::max_size; while (size >= N) { simd::store(dst, Op{}(simd::load(a), simd::load(b))); dst += N; a += N; b += N; size -= N; } while (size-- > 0) { *dst = Op{}(*a, *b); dst++; a++; b++; } } }; template void binary_op_dims( const T* a, const T* b, U* out, const Shape& shape, const Strides& a_strides, const Strides& b_strides, const Strides& out_strides, int axis) { auto stride_a = a_strides[axis]; auto stride_b = b_strides[axis]; auto stride_out = out_strides[axis]; auto N = shape[axis]; for (int i = 0; i < N; i++) { if constexpr (D > 1) { binary_op_dims( a, b, out, shape, a_strides, b_strides, out_strides, axis + 1); } else { if constexpr (Strided) { Op{}(a, b, out, stride_out); } else { *out = Op{}(*a, *b); } } out += stride_out; a += stride_a; b += stride_b; } } template void binary_op_dispatch_dims( const T* a, const T* b, U* out, int dim, int size, const Shape& shape, const Strides& a_strides, const Strides& b_strides, const Strides& out_strides) { switch (dim) { case 1: binary_op_dims( a, b, 
out, shape, a_strides, b_strides, out_strides, 0); return; case 2: binary_op_dims( a, b, out, shape, a_strides, b_strides, out_strides, 0); return; case 3: binary_op_dims( a, b, out, shape, a_strides, b_strides, out_strides, 0); return; } ContiguousIterator a_it(shape, a_strides, dim - 3); ContiguousIterator b_it(shape, b_strides, dim - 3); auto stride = out_strides[dim - 4]; for (int64_t elem = 0; elem < size; elem += stride) { binary_op_dims( a + a_it.loc, b + b_it.loc, out + elem, shape, a_strides, b_strides, out_strides, dim - 3); a_it.step(); b_it.step(); } } template void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) { // The full computation is scalar scalar so call the base op once auto a_ptr = a.data(); auto b_ptr = b.data(); auto out_ptr = out.data(); if (bopt == BinaryOpType::ScalarScalar) { *out_ptr = Op{}(*a_ptr, *b_ptr); return; } // The full computation is scalar vector so delegate to the op if (bopt == BinaryOpType::ScalarVector) { ScalarVector{}(a_ptr, b_ptr, out_ptr, b.data_size()); return; } // The full computation is vector scalar so delegate to the op if (bopt == BinaryOpType::VectorScalar) { VectorScalar{}(a_ptr, b_ptr, out_ptr, a.data_size()); return; } // The full computation is vector vector so delegate to the op if (bopt == BinaryOpType::VectorVector) { VectorVector{}(a_ptr, b_ptr, out_ptr, a.size()); return; } // General computation so let's try to optimize auto [new_shape, new_strides] = collapse_contiguous_dims( a.shape(), {a.strides(), b.strides(), out.strides()}); auto& a_strides = new_strides[0]; auto& b_strides = new_strides[1]; auto& strides = new_strides[2]; // Get the left-most dim such that the array is row contiguous after auto leftmost_rc_dim = [&strides](const auto& arr_strides) { int d = arr_strides.size() - 1; for (; d >= 0 && arr_strides[d] == strides[d]; d--) { } return d + 1; }; auto a_rc_dim = leftmost_rc_dim(a_strides); auto b_rc_dim = leftmost_rc_dim(b_strides); // Get the left-most dim such 
that the array is a broadcasted "scalar" after auto leftmost_s_dim = [](const auto& arr_strides) { int d = arr_strides.size() - 1; for (; d >= 0 && arr_strides[d] == 0; d--) { } return d + 1; }; auto a_s_dim = leftmost_s_dim(a_strides); auto b_s_dim = leftmost_s_dim(b_strides); auto ndim = new_shape.size(); // Case 1: LxM and FxM where L and F are broadcastable and M is row // contiguous int dim = ndim; if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { bopt = BinaryOpType::VectorVector; dim = d; // Case 2: LxM and Fx1 where L and F are broadcastable and M is row // contiguous } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { bopt = BinaryOpType::VectorScalar; dim = d; // Case 3: Lx1 and FxM where L and F are broadcastable and M is row // contiguous } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { bopt = BinaryOpType::ScalarVector; dim = d; } // Can be sure dim > 0 since otherwise we would have used one of the fully // contiguous methods above. Except for the case that the flags do not // correspond to the underlying contiguity. 
if (dim == 0 || strides[dim - 1] < 16) { bopt = BinaryOpType::General; dim = ndim; } switch (bopt) { case BinaryOpType::VectorVector: binary_op_dispatch_dims>( a_ptr, b_ptr, out_ptr, dim, a.size(), new_shape, a_strides, b_strides, strides); break; case BinaryOpType::VectorScalar: binary_op_dispatch_dims>( a_ptr, b_ptr, out_ptr, dim, a.size(), new_shape, a_strides, b_strides, strides); break; case BinaryOpType::ScalarVector: binary_op_dispatch_dims>( a_ptr, b_ptr, out_ptr, dim, a.size(), new_shape, a_strides, b_strides, strides); break; default: binary_op_dispatch_dims( a_ptr, b_ptr, out_ptr, dim, a.size(), new_shape, a_strides, b_strides, strides); break; } } template void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) { binary_op(a, b, out, bopt); } template void binary_op_cpu( const array& a, const array& b, array& out, Op op, Stream stream) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); encoder.dispatch([a = array::unsafe_weak_copy(a), b = array::unsafe_weak_copy(b), out = array::unsafe_weak_copy(out), bopt]() mutable { switch (out.dtype()) { case bool_: binary_op(a, b, out, bopt); break; case uint8: binary_op(a, b, out, bopt); break; case uint16: binary_op(a, b, out, bopt); break; case uint32: binary_op(a, b, out, bopt); break; case uint64: binary_op(a, b, out, bopt); break; case int8: binary_op(a, b, out, bopt); break; case int16: binary_op(a, b, out, bopt); break; case int32: binary_op(a, b, out, bopt); break; case int64: binary_op(a, b, out, bopt); break; case float16: binary_op(a, b, out, bopt); break; case float32: binary_op(a, b, out, bopt); break; case float64: binary_op(a, b, out, bopt); break; case bfloat16: binary_op(a, b, out, bopt); break; case complex64: binary_op(a, b, out, bopt); break; } }); } template void comparison_op_cpu( const array& a, const 
array& b, array& out, Op op, Stream stream) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); encoder.dispatch([a = array::unsafe_weak_copy(a), b = array::unsafe_weak_copy(b), out = array::unsafe_weak_copy(out), bopt]() mutable { switch (a.dtype()) { case bool_: binary_op(a, b, out, bopt); break; case uint8: binary_op(a, b, out, bopt); break; case uint16: binary_op(a, b, out, bopt); break; case uint32: binary_op(a, b, out, bopt); break; case uint64: binary_op(a, b, out, bopt); break; case int8: binary_op(a, b, out, bopt); break; case int16: binary_op(a, b, out, bopt); break; case int32: binary_op(a, b, out, bopt); break; case int64: binary_op(a, b, out, bopt); break; case float16: binary_op(a, b, out, bopt); break; case float32: binary_op(a, b, out, bopt); break; case float64: binary_op(a, b, out, bopt); break; case bfloat16: binary_op(a, b, out, bopt); break; case complex64: binary_op(a, b, out, bopt); break; } }); } template void binary_float_op_cpu( const array& a, const array& b, array& out, Op op, Stream stream) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); encoder.dispatch([a = array::unsafe_weak_copy(a), b = array::unsafe_weak_copy(b), out = array::unsafe_weak_copy(out), bopt]() mutable { switch (out.dtype()) { case float16: binary_op(a, b, out, bopt); break; case float32: binary_op(a, b, out, bopt); break; case float64: binary_op(a, b, out, bopt); break; case bfloat16: binary_op(a, b, out, bopt); break; case complex64: binary_op(a, b, out, bopt); break; default: throw std::runtime_error( "[binary_float] Only supports floating point types."); } }); } template void binary_int_op_cpu( const array& a, const array& b, 
array& out, Op op, Stream stream) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); encoder.dispatch([a = array::unsafe_weak_copy(a), b = array::unsafe_weak_copy(b), out = array::unsafe_weak_copy(out), bopt]() mutable { switch (out.dtype()) { case bool_: binary_op(a, b, out, bopt); case uint8: binary_op(a, b, out, bopt); break; case uint16: binary_op(a, b, out, bopt); break; case uint32: binary_op(a, b, out, bopt); break; case uint64: binary_op(a, b, out, bopt); break; case int8: binary_op(a, b, out, bopt); break; case int16: binary_op(a, b, out, bopt); break; case int32: binary_op(a, b, out, bopt); break; case int64: binary_op(a, b, out, bopt); break; default: throw std::runtime_error("[binary_int] Type not supported"); break; } }); } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/binary_ops.h ================================================ // Copyright © 2023-2024 Apple Inc. 
#pragma once #include "mlx/backend/cpu/simd/simd.h" namespace mlx::core::detail { using namespace mlx::core::simd; #define BINARY_SINGLE() \ template \ T operator()(T x, T y) { \ return (*this)(Simd(x), Simd(y)).value; \ } #define DEFAULT_BINARY_OP(Op, op) \ struct Op { \ template \ Simd operator()(Simd x, Simd y) { \ return op(x, y); \ } \ BINARY_SINGLE() \ }; DEFAULT_BINARY_OP(Add, operator+) DEFAULT_BINARY_OP(ArcTan2, atan2) DEFAULT_BINARY_OP(Divide, operator/) DEFAULT_BINARY_OP(Multiply, operator*) DEFAULT_BINARY_OP(Subtract, operator-) DEFAULT_BINARY_OP(LogicalAnd, operator&&) DEFAULT_BINARY_OP(LogicalOr, operator||) DEFAULT_BINARY_OP(BitwiseAnd, operator&) DEFAULT_BINARY_OP(BitwiseOr, operator|) DEFAULT_BINARY_OP(BitwiseXor, operator^) DEFAULT_BINARY_OP(LeftShift, operator<<) DEFAULT_BINARY_OP(RightShift, operator>>) DEFAULT_BINARY_OP(Remainder, remainder) DEFAULT_BINARY_OP(Maximum, maximum) DEFAULT_BINARY_OP(Minimum, minimum) DEFAULT_BINARY_OP(Power, pow) #define DEFAULT_BOOL_OP(Op, op) \ struct Op { \ template \ Simd operator()(Simd x, Simd y) { \ return op(x, y); \ } \ template \ bool operator()(T x, T y) { \ return (*this)(Simd(x), Simd(y)).value; \ } \ }; DEFAULT_BOOL_OP(Equal, operator==) DEFAULT_BOOL_OP(Greater, operator>) DEFAULT_BOOL_OP(GreaterEqual, operator>=) DEFAULT_BOOL_OP(Less, operator<) DEFAULT_BOOL_OP(LessEqual, operator<=) DEFAULT_BOOL_OP(NotEqual, operator!=) struct NaNEqual { template Simd operator()(Simd x, Simd y) { return x == y || (isnan(x) && isnan(y)); } template bool operator()(T x, T y) { return (*this)(Simd(x), Simd(y)).value; } }; struct LogAddExp { template Simd operator()(Simd x, Simd y) { auto maxval = maximum(x, y); auto minval = minimum(x, y); auto mask = minval == -inf || maxval == inf; auto out = maxval + log1p(exp(minval - maxval)); return select(mask, Simd(maxval), Simd(out)); } BINARY_SINGLE() }; struct Select { template T operator()(bool condition, T x, T y) { return (*this)(Simd(condition), Simd(x), Simd(y)) .value; 
} template Simd operator()(Simd condition, Simd x, Simd y) { return select(condition, x, y); } }; } // namespace mlx::core::detail ================================================ FILE: mlx/backend/cpu/binary_two.h ================================================ // Copyright © 2023 Apple Inc. #pragma once #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/binary.h" namespace mlx::core { namespace { template void binary_op_dims( const T* a, const T* b, U* out_a, U* out_b, Op op, const Shape& shape, const Strides& a_strides, const Strides& b_strides, const Strides& out_strides, int axis) { auto stride_a = a_strides[axis]; auto stride_b = b_strides[axis]; auto stride_out = out_strides[axis]; auto N = shape[axis]; for (int i = 0; i < N; i++) { if constexpr (D > 1) { binary_op_dims( a, b, out_a, out_b, op, shape, a_strides, b_strides, out_strides, axis + 1); } else { std::tie(*out_a, *out_b) = op(*a, *b); } a += stride_a; b += stride_b; out_a += stride_out; out_b += stride_out; } } template void binary_op_dispatch_dims( const array& a, const array& b, array& out_a, array& out_b, Op op) { auto [shape, strides] = collapse_contiguous_dims( a.shape(), {a.strides(), b.strides(), out_a.strides()}); const T* a_ptr = a.data(); const T* b_ptr = b.data(); U* out_a_ptr = out_a.data(); U* out_b_ptr = out_b.data(); const auto& a_strides = strides[0]; const auto& b_strides = strides[1]; const auto& out_strides = strides[2]; int ndim = shape.size(); switch (ndim) { case 1: binary_op_dims( a_ptr, b_ptr, out_a_ptr, out_b_ptr, op, shape, a_strides, b_strides, out_strides, 0); return; case 2: binary_op_dims( a_ptr, b_ptr, out_a_ptr, out_b_ptr, op, shape, a_strides, b_strides, out_strides, 0); return; } ContiguousIterator a_it(shape, a_strides, ndim - 2); ContiguousIterator b_it(shape, b_strides, ndim - 2); auto stride = out_strides[ndim - 3]; for (size_t elem = 0; elem < a.size(); elem += stride) { binary_op_dims( a_ptr + a_it.loc, b_ptr + b_it.loc, out_a_ptr + elem, 
out_b_ptr + elem, op, shape, a_strides, b_strides, out_strides, ndim - 2); a_it.step(); b_it.step(); } } template void binary_op( const array& a, const array& b, array& out_a, array& out_b, Op op, BinaryOpType bopt) { // The full computation is scalar scalar so call the base op once if (bopt == BinaryOpType::General) { binary_op_dispatch_dims(a, b, out_a, out_b, op); return; } auto a_ptr = a.data(); auto b_ptr = b.data(); auto out_a_ptr = out_a.data(); auto out_b_ptr = out_b.data(); if (bopt == BinaryOpType::ScalarScalar) { std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); } else if (bopt == BinaryOpType::ScalarVector) { for (size_t i = 0; i < b.data_size(); ++i) { std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); out_a_ptr++; out_b_ptr++; b_ptr++; } } else if (bopt == BinaryOpType::VectorScalar) { for (size_t i = 0; i < a.data_size(); ++i) { std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); out_a_ptr++; out_b_ptr++; a_ptr++; } } else { // VectorVector for (size_t i = 0; i < a.size(); ++i) { std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); out_a_ptr++; out_b_ptr++; a_ptr++; b_ptr++; } } } } // namespace } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/cholesky.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include "mlx/allocator.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/linalg.h" #include "mlx/primitives.h" namespace mlx::core { template void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) { // Lapack uses the column-major convention. We take advantage of the fact that // the matrix should be symmetric: // (A)ᵀ = A // and that a column-major lower triangular matrix is a row-major upper // triangular matrix, so uplo is the opposite of what we would expect from // upper // The decomposition is computed in place, so just copy the input to the // output. 
copy_cpu( a, factor, a.flags().row_contiguous ? CopyType::Vector : CopyType::General, stream); auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(factor); encoder.dispatch([matrix = factor.data(), upper, N = a.shape(-1), size = a.size()]() mutable { char uplo = (upper) ? 'L' : 'U'; size_t num_matrices = size / (N * N); for (int i = 0; i < num_matrices; i++) { // Compute Cholesky factorization. int info; potrf( /* uplo = */ &uplo, /* n = */ &N, /* a = */ matrix, /* lda = */ &N, /* info = */ &info); // TODO: We do nothing when the matrix is not positive semi-definite // because throwing an error would result in a crash. If we figure out how // to catch errors from the implementation we should throw. if (info < 0) { std::stringstream msg; msg << "[Cholesky::eval_cpu] Cholesky decomposition failed with error code " << info; throw std::runtime_error(msg.str()); } // Zero out the upper/lower triangle while advancing the pointer to the // next matrix at the same time. for (int row = 0; row < N; row++) { if (upper) { std::fill(matrix, matrix + row, 0); } else { std::fill(matrix + row + 1, matrix + N, 0); } matrix += N; } } }); } void Cholesky::eval_cpu(const std::vector& inputs, array& output) { switch (inputs[0].dtype()) { case float32: cholesky_impl(inputs[0], output, upper_, stream()); break; case float64: cholesky_impl(inputs[0], output, upper_, stream()); break; default: throw std::runtime_error( "[Cholesky::eval_cpu] only supports float32 or float64."); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/compiled.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include #include #include #include #include #include #include #include "mlx/backend/common/compiled.h" #include "mlx/backend/cpu/compiled_preamble.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/jit_compiler.h" #include "mlx/device.h" #include "mlx/graph_utils.h" #include "mlx/version.h" namespace mlx::core { struct CompilerCache { struct DLib { DLib(const std::string& libname) { lib = dlopen(libname.c_str(), RTLD_NOW); if (!lib) { std::ostringstream msg; msg << "Could not load C++ shared library " << dlerror(); throw std::runtime_error(msg.str()); } } ~DLib() { dlclose(lib); } void* lib; }; // Statics to cache compiled libraries and functions std::list libs; std::unordered_map kernels; std::shared_mutex mtx; }; static CompilerCache& cache() { static CompilerCache cache_; return cache_; }; // GPU compile is always available if the GPU is available and since we are in // this file CPU compile is also available. namespace detail { bool compile_available_for_device(const Device& device) { return true; } } // namespace detail // Return a pointer to a compiled function void* compile( const std::string& kernel_name, const std::function& source_builder) { { std::shared_lock lock(cache().mtx); if (auto it = cache().kernels.find(kernel_name); it != cache().kernels.end()) { return it->second; } } std::unique_lock lock(cache().mtx); if (auto it = cache().kernels.find(kernel_name); it != cache().kernels.end()) { return it->second; } std::string source_code = source_builder(); std::string kernel_file_name; // Deal with long kernel names. Maximum length for filename on macOS is 255 // characters, and on Windows the maximum length for whole path is 260. Clip // file name with a little extra room and append a 16 character hash. 
#ifdef _WIN32 constexpr int max_file_name_length = 140; #else constexpr int max_file_name_length = 245; #endif if (kernel_name.size() > max_file_name_length) { std::ostringstream file_name; file_name << std::string_view(kernel_name).substr(0, max_file_name_length - 16); auto file_id = std::hash{}(kernel_name.substr(max_file_name_length - 16)); file_name << "_" << std::hex << std::setw(16) << file_id << std::dec; kernel_file_name = file_name.str(); } else { kernel_file_name = kernel_name; } auto output_dir = std::filesystem::temp_directory_path() / "mlx" / version() / "cpu"; if (!std::filesystem::exists(output_dir)) { std::filesystem::create_directories(output_dir); } std::string shared_lib_name = "lib" + kernel_file_name + ".so"; auto shared_lib_path = (output_dir / shared_lib_name).string(); bool lib_exists = false; { std::ifstream f(shared_lib_path.c_str()); lib_exists = f.good(); } if (!lib_exists) { // Open source file and write source code to it std::string source_file_name = kernel_file_name + ".cpp"; auto source_file_path = (output_dir / source_file_name).string(); std::ofstream source_file(source_file_path); source_file << source_code; source_file.close(); try { JitCompiler::exec( JitCompiler::build_command( output_dir, source_file_name, shared_lib_name)); } catch (const std::exception& error) { throw std::runtime_error( fmt::format( "[Compile::eval_cpu] Failed to compile function {0}: {1}", kernel_name, error.what())); } } // load library cache().libs.emplace_back(shared_lib_path); // Load function void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str()); if (!fun) { std::ostringstream msg; msg << "[Compile::eval_cpu] Failed to load compiled function " << kernel_name << std::endl << dlerror(); throw std::runtime_error(msg.str()); } cache().kernels.insert({kernel_name, fun}); return fun; } inline void build_kernel( std::ostream& os, const std::string& kernel_name, const std::vector& inputs, const std::vector& outputs, const std::vector& tape, const 
std::function& is_constant, bool contiguous, int ndim) { NodeNamer namer; #ifdef _MSC_VER // Export the symbol os << "__declspec(dllexport) "; #endif // Start the kernel os << "void " << kernel_name << "(int* shape, int64_t** strides, void** args) {" << std::endl; // Add the input arguments int cnt = 0; int strides_index = 1; for (size_t i = 0; i < inputs.size(); ++i) { // Skip constants from the input list if (is_constant(i)) { continue; } const auto& x = inputs[i]; auto& xname = namer.get_name(x); auto tstr = get_type_string(x.dtype()); os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++ << "];" << std::endl; // Scalars and contiguous need no strides if (!is_scalar(x) && !contiguous) { os << " const int64_t* " << xname << "_strides = strides[" << strides_index++ << "];" << std::endl; } } // Add the output arguments for (auto& x : outputs) { auto tstr = get_type_string(x.dtype()); os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr << "*)args[" << cnt++ << "];" << std::endl; } // Add output size if (contiguous) { os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl; } if (contiguous) { os << " for (size_t i = 0; i < size; ++i) {" << std::endl; } else { for (int d = 0; d < ndim; ++d) { os << " for (int i" << d << " = 0; i" << d << " < shape[" << d << "]; ++i" << d << ") {" << std::endl; } } // Read the inputs in tmps for (size_t i = 0; i < inputs.size(); ++i) { const auto& x = inputs[i]; auto& xname = namer.get_name(x); if (is_constant(i)) { os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; print_constant(os, x); os << ";" << std::endl; } else if (is_scalar(x)) { os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " << xname << "[0];" << std::endl; } else if (contiguous) { os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " << xname << "[i];" << std::endl; } else { os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *" << xname << ";" << 
std::endl; } } // Actually write the computation for (auto& x : tape) { os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x) << " = "; if (is_static_cast(x.primitive())) { os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" << namer.get_name(x.inputs()[0]) << ");" << std::endl; } else { os << x.primitive().name(); os << "()("; for (int i = 0; i < x.inputs().size() - 1; i++) { os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; } os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl; } } // Write the outputs from tmps for (auto& x : outputs) { if (contiguous) { os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x) << ";" << std::endl; } else { os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x) << ";" << std::endl; } } // Close loops if (contiguous) { os << " }" << std::endl; } else { for (int d = ndim - 1; d >= 0; --d) { // Update pointers for (size_t i = 0; i < inputs.size(); ++i) { const auto& x = inputs[i]; if (is_constant(i) || is_scalar(x)) { continue; } auto& xname = namer.get_name(x); os << " " << xname << " += " << xname << "_strides[" << d << "];" << std::endl; if (d < ndim - 1) { os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]" << " * shape[" << d + 1 << "];" << std::endl; } } os << " }" << std::endl; } } // Finish the kernel os << "}" << std::endl; } void Compiled::eval_cpu( const std::vector& inputs, std::vector& outputs) { auto& encoder = cpu::get_command_encoder(stream()); // Collapse contiguous dims to route to a faster kernel if possible. Also // handle all broadcasting. auto [contiguous, shape, strides] = compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); // Collect function input arguments. 
std::vector args; for (size_t i = 0; i < inputs.size(); ++i) { if (is_constant_(i)) { continue; } const auto& x = inputs[i]; encoder.set_input_array(x); args.push_back((void*)x.data()); } // Get the kernel name from the lib int ndim = shape.size(); auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); if (!contiguous) { kernel_name += std::to_string(ndim); } // Get the function auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() { std::ostringstream kernel; kernel << get_kernel_preamble() << std::endl; kernel << "extern \"C\" {" << std::endl; build_kernel( kernel, kernel_name, inputs_, outputs_, tape_, is_constant_, contiguous, ndim); // Close extern "C" kernel << "}" << std::endl; return kernel.str(); }); compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); for (auto& x : outputs) { args.push_back(x.data()); encoder.set_output_array(x); } if (contiguous) { args.push_back((void*)outputs[0].data_size()); } auto fun = reinterpret_cast(fn_ptr); encoder.dispatch([fun, args = std::move(args), strides = std::move(strides), shape = std::move(shape)]() mutable { SmallVector strides_ptrs; for (auto& s : strides) { strides_ptrs.push_back(s.data()); } fun(shape.data(), strides_ptrs.data(), args.data()); }); } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/compiled_preamble.h ================================================ // Copyright © 2023-24 Apple Inc. #pragma once // clang-format off #include "mlx/types/half_types.h" #include "mlx/types/complex.h" #include "mlx/backend/cpu/unary_ops.h" #include "mlx/backend/cpu/binary_ops.h" // clang-format on const char* get_kernel_preamble(); ================================================ FILE: mlx/backend/cpu/conv.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include #include #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { namespace { /////////////////////////////////////////////////////////////////////////////// // Naive reference conv /////////////////////////////////////////////////////////////////////////////// template void slow_conv_1D( const array& in, const array& wt, array out, const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip, Stream stream) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(in); encoder.set_input_array(wt); encoder.set_output_array(out); encoder.dispatch([start_wt_ptr = wt.data(), in_ptr = in.data(), out_ptr = out.data(), N = in.shape( 0), // Batch size, should be the same as out.shape(0) iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim oH = out.shape(1), // Output spatial dim wH = wt.shape(1), // Weight spatial dim groups = in.shape(2) / wt.shape(2), O = wt.shape(0), // Out channels C_per_group = wt.shape(2), in_stride_N = in.strides()[0], in_stride_H = in.strides()[1], in_stride_C = in.strides()[2], wt_stride_O = wt.strides()[0], wt_stride_H = wt.strides()[1], wt_stride_C = wt.strides()[2], out_stride_N = out.strides()[0], out_stride_H = out.strides()[1], out_stride_O = out.strides()[2], flip, padding_lo = padding_lo[0], padding_hi = padding_hi[0], wt_stride = wt_strides[0], wt_dilation = wt_dilation[0], in_dilation = in_dilation[0]]() mutable { auto O_per_group = O / groups; for (int n = 0; n < N; ++n) { for (int oh = 0; oh < oH; ++oh) { for (int g = 0; g < groups; ++g) { for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O; float r = 0.; for (int wh = 0; wh < wH; ++wh) { const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; int wh_flip = 
flip ? (wH - wh - 1) : wh; int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation; auto ih_div = std::div(ih, in_dilation); if (ih >= 0 && ih < iH && ih_div.rem == 0) { for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { r += static_cast( in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) * static_cast( wt_ptr[(c % C_per_group) * wt_stride_C]); } // c } // ih check } // wh out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast(r); } // o } // g } // oh in_ptr += in_stride_N; out_ptr += out_stride_N; } // n }); } template void slow_conv_2D( const array& in, const array& wt, array out, const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip, Stream stream) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(in); encoder.set_input_array(wt); encoder.set_output_array(out); encoder.dispatch( [st_wt_ptr = wt.data(), st_in_ptr = in.data(), st_out_ptr = out.data(), N = in.shape(0), // Batch size, should be the same as out.shape(0) iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim C = in.shape(3), // In channels oH = out.shape(1), // Output spatial dim oW = out.shape(2), // Output spatial dim O = wt.shape(0), // Out channels wH = wt.shape(1), // Weight spatial dim wW = wt.shape(2), // Weight spatial dim groups = in.shape(3) / wt.shape(3), C_per_group = wt.shape(3), in_stride_N = in.strides()[0], in_stride_H = in.strides()[1], in_stride_W = in.strides()[2], in_stride_C = in.strides()[3], wt_stride_O = wt.strides()[0], wt_stride_H = wt.strides()[1], wt_stride_W = wt.strides()[2], wt_stride_C = wt.strides()[3], out_stride_N = out.strides()[0], out_stride_H = out.strides()[1], out_stride_W = out.strides()[2], out_stride_O = out.strides()[3], padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip]() mutable { bool is_idil_one = 
in_dilation[0] == 1 && in_dilation[1] == 1; const int O_per_group = O / groups; auto pt_conv_no_checks = [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { out_ptr += oh * out_stride_H + ow * out_stride_W; int ih_base = oh * wt_strides[0] - padding_lo[0]; int iw_base = ow * wt_strides[1] - padding_lo[1]; for (int g = 0; g < groups; ++g) { for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { float r = 0.; for (int wh = 0; wh < wH; ++wh) { for (int ww = 0; ww < wW; ++ww) { int wh_flip = flip ? wH - wh - 1 : wh; int ww_flip = flip ? wW - ww - 1 : ww; int ih = ih_base + wh_flip * wt_dilation[0]; int iw = iw_base + ww_flip * wt_dilation[1]; const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { r += static_cast(in_ptr_pt[c * in_stride_C]) * static_cast( wt_ptr_pt[(c % C_per_group) * wt_stride_C]); } // c } // ww } // wh out_ptr[0] = static_cast(r); out_ptr += out_stride_O; wt_ptr += wt_stride_O; } // o } // g }; int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); int init_w = (flip ? 
(wW - 1) * wt_dilation[1] : 0); int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; std::vector base_h(f_out_jump_h); std::vector base_w(f_out_jump_w); for (int i = 0; i < f_out_jump_h; ++i) { int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h; int wh_base = 0; while (wh_base < wH && ih_loop % in_dilation[0] != 0) { wh_base++; ih_loop += jump_h; } base_h[i] = wh_base; } for (int j = 0; j < f_out_jump_w; ++j) { int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w; int ww_base = 0; while (ww_base < wW && iw_loop % in_dilation[1] != 0) { ww_base++; iw_loop += jump_w; } base_w[j] = ww_base; } auto pt_conv_all_checks = [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { out_ptr += oh * out_stride_H + ow * out_stride_W; int ih_base = oh * wt_strides[0] - padding_lo[0]; int iw_base = ow * wt_strides[1] - padding_lo[1]; int wh_base = base_h[oh % f_out_jump_h]; int ww_base = base_w[ow % f_out_jump_w]; for (int g = 0; g < groups; ++g) { for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { float r = 0.; for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { int wh_flip = flip ? wH - wh - 1 : wh; int ww_flip = flip ? wW - ww - 1 : ww; int ih = ih_base + wh_flip * wt_dilation[0]; int iw = iw_base + ww_flip * wt_dilation[1]; if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; int iw_dil = !is_idil_one ? 
(iw / in_dilation[1]) : iw; const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { r += static_cast(in_ptr_pt[c * in_stride_C]) * static_cast( wt_ptr_pt[(c % C_per_group) * wt_stride_C]); } // c } // ih, iw check } // ww } // wh out_ptr[0] = static_cast(r); out_ptr += out_stride_O; wt_ptr += wt_stride_O; } // o } // g }; int oH_border_0 = 0; int oH_border_1 = is_idil_one ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; int oH_border_2 = std::max( oH_border_1, (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]); int oH_border_3 = oH; int oW_border_0 = 0; int oW_border_1 = is_idil_one ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; int oW_border_2 = std::max( oW_border_1, (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]); int oW_border_3 = oW; for (int n = 0; n < N; ++n) { // Case 1: oh might put us out of bounds for (int oh = oH_border_0; oh < oH_border_1; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); } // ow } // oh // Case 2: oh in bounds for (int oh = oH_border_1; oh < oH_border_2; ++oh) { // Case a: ow might put us out of bounds for (int ow = oW_border_0; ow < oW_border_1; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); } // ow // Case b: ow in bounds for (int ow = oW_border_1; ow < oW_border_2; ++ow) { pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); } // ow // Case c: ow might put us out of bounds for (int ow = oW_border_2; ow < oW_border_3; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); } // ow } // oh // Case 3: oh might put us out of bounds for (int oh = oH_border_2; oh < oH_border_3; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); } // ow } // oh st_in_ptr += in_stride_N; st_out_ptr += out_stride_N; } // n }); } template void slow_conv_3D( const array& in, const 
array& wt, array out, const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip, Stream stream) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(in); encoder.set_input_array(wt); encoder.set_output_array(out); encoder.dispatch([st_wt_ptr = wt.data(), st_in_ptr = in.data(), st_out_ptr = out.data(), N = in.shape( 0), // Batch size, should be the same as out.shape(0) iD = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim iH = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim iW = 1 + in_dilation[2] * (in.shape(3) - 1), // Input spatial dim oD = out.shape(1), // Output spatial dim oH = out.shape(2), // Output spatial dim oW = out.shape(3), // Output spatial dim O = wt.shape(0), // Out channels C = wt.shape(4), // In channels wD = wt.shape(1), // Weight spatial dim wH = wt.shape(2), // Weight spatial dim wW = wt.shape(3), // Weight spatial dim in_stride_N = in.strides()[0], in_stride_D = in.strides()[1], in_stride_H = in.strides()[2], in_stride_W = in.strides()[3], in_stride_C = in.strides()[4], wt_stride_O = wt.strides()[0], wt_stride_D = wt.strides()[1], wt_stride_H = wt.strides()[2], wt_stride_W = wt.strides()[3], wt_stride_C = wt.strides()[4], out_stride_N = out.strides()[0], out_stride_D = out.strides()[1], out_stride_H = out.strides()[2], out_stride_W = out.strides()[3], out_stride_O = out.strides()[4], padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip]() mutable { bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1; auto pt_conv_no_checks = [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int od, int oh, int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; int id_base = od * wt_strides[0] - padding_lo[0]; int ih_base = oh * wt_strides[1] - padding_lo[1]; int iw_base = ow * wt_strides[2] - padding_lo[2]; for (int o = 0; o < O; ++o) 
{ float r = 0.; for (int wd = 0; wd < wD; ++wd) { for (int wh = 0; wh < wH; ++wh) { for (int ww = 0; ww < wW; ++ww) { int wd_flip = flip ? wD - wd - 1 : wd; int wh_flip = flip ? wH - wh - 1 : wh; int ww_flip = flip ? wW - ww - 1 : ww; int id = id_base + wd_flip * wt_dilation[0]; int ih = ih_base + wh_flip * wt_dilation[1]; int iw = iw_base + ww_flip * wt_dilation[2]; const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W; const T* in_ptr_pt = in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W; for (int c = 0; c < C; ++c) { r += static_cast(in_ptr_pt[0]) * static_cast(wt_ptr_pt[0]); in_ptr_pt += in_stride_C; wt_ptr_pt += wt_stride_C; } // c } // ww } // wh } // wd out_ptr[0] = static_cast(r); out_ptr += out_stride_O; wt_ptr += wt_stride_O; } // o }; int jump_d = flip ? -wt_dilation[0] : wt_dilation[0]; int jump_h = flip ? -wt_dilation[1] : wt_dilation[1]; int jump_w = flip ? -wt_dilation[2] : wt_dilation[2]; int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0); int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0); int init_w = (flip ? 
(wW - 1) * wt_dilation[2] : 0); int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2]; int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2]; std::vector base_d(f_out_jump_d); std::vector base_h(f_out_jump_h); std::vector base_w(f_out_jump_w); for (int i = 0; i < f_out_jump_d; ++i) { int id_loop = i * wt_strides[0] - padding_lo[0] + init_d; int wd_base = 0; while (wd_base < wD && id_loop % in_dilation[0] != 0) { wd_base++; id_loop += jump_d; } base_d[i] = wd_base; } for (int i = 0; i < f_out_jump_h; ++i) { int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h; int wh_base = 0; while (wh_base < wH && ih_loop % in_dilation[1] != 0) { wh_base++; ih_loop += jump_h; } base_h[i] = wh_base; } for (int j = 0; j < f_out_jump_w; ++j) { int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w; int ww_base = 0; while (ww_base < wW && iw_loop % in_dilation[2] != 0) { ww_base++; iw_loop += jump_w; } base_w[j] = ww_base; } auto pt_conv_all_checks = [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int od, int oh, int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; int id_base = od * wt_strides[0] - padding_lo[0]; int ih_base = oh * wt_strides[1] - padding_lo[1]; int iw_base = ow * wt_strides[2] - padding_lo[2]; int wd_base = base_d[od % f_out_jump_d]; int wh_base = base_h[oh % f_out_jump_h]; int ww_base = base_w[ow % f_out_jump_w]; for (int o = 0; o < O; ++o) { float r = 0.; for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) { for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { int wd_flip = flip ? wD - wd - 1 : wd; int wh_flip = flip ? 
wH - wh - 1 : wh; int ww_flip = flip ? wW - ww - 1 : ww; int id = id_base + wd_flip * wt_dilation[0]; int ih = ih_base + wh_flip * wt_dilation[1]; int iw = iw_base + ww_flip * wt_dilation[2]; if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 && iw < iW) { const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W; int id_dil = !is_idil_one ? (id / in_dilation[0]) : id; int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih; int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw; const T* in_ptr_pt = in_ptr + id_dil * in_stride_D + ih_dil * in_stride_H + iw_dil * in_stride_W; for (int c = 0; c < C; ++c) { r += static_cast(in_ptr_pt[0]) * static_cast(wt_ptr_pt[0]); in_ptr_pt += in_stride_C; wt_ptr_pt += wt_stride_C; } // c } // iD, ih, iw check } // ww } // wh } // wd out_ptr[0] = static_cast(r); out_ptr += out_stride_O; wt_ptr += wt_stride_O; } // o }; int oD_border_0 = 0; int oD_border_1 = is_idil_one ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; int oD_border_2 = std::max( oD_border_1, (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]); int oD_border_3 = oD; int oH_border_0 = 0; int oH_border_1 = is_idil_one ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; int oH_border_2 = std::max( oH_border_1, (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]); int oH_border_3 = oH; int oW_border_0 = 0; int oW_border_1 = is_idil_one ? 
((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; int oW_border_2 = std::max( oW_border_1, (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]); int oW_border_3 = oW; for (int n = 0; n < N; ++n) { // Case 1: od might put us out of bounds for (int od = oD_border_0; od < oD_border_1; ++od) { for (int oh = 0; oh < oH; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow } // oh } // od // Case 2: od in bounds for (int od = oD_border_1; od < oD_border_2; ++od) { // Case 2.1: oh might put us out of bounds for (int oh = oH_border_0; oh < oH_border_1; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow } // oh // Case 2.2: oh in bounds for (int oh = oH_border_1; oh < oH_border_2; ++oh) { // Case 2.2.1: ow might put us out of bounds for (int ow = oW_border_0; ow < oW_border_1; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow // Case 2.2.2: ow in bounds for (int ow = oW_border_1; ow < oW_border_2; ++ow) { pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow // Case 2.2.3: ow might put us out of bounds for (int ow = oW_border_2; ow < oW_border_3; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow } // oh // Case 2.3: oh might put us out of bounds for (int oh = oH_border_2; oh < oH_border_3; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow } // oh } // od // Case 3: od might put us out of bounds for (int od = oD_border_2; od < oD_border_3; ++od) { for (int oh = 0; oh < oH; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow } // oh } // od st_in_ptr += in_stride_N; st_out_ptr += out_stride_N; } // n }); } void dispatch_slow_conv_1D( const array& in, const array& wt, array out, const std::vector& padding_lo, const 
std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip, Stream stream) { if (in.dtype() == float32) { return slow_conv_1D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } else if (in.dtype() == float16) { return slow_conv_1D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } else if (in.dtype() == bfloat16) { return slow_conv_1D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); } } void dispatch_slow_conv_2D( const array& in, const array& wt, array out, const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip, Stream stream) { if (in.dtype() == float32) { return slow_conv_2D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } else if (in.dtype() == float16) { return slow_conv_2D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } else if (in.dtype() == bfloat16) { return slow_conv_2D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); } } void dispatch_slow_conv_3D( const array& in, const array& wt, array out, const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip, Stream stream) { if (in.dtype() == float32) { return slow_conv_3D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } else if (in.dtype() == float16) { return slow_conv_3D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, 
stream); } else if (in.dtype() == bfloat16) { return slow_conv_3D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); } } /////////////////////////////////////////////////////////////////////////////// // Explicit gemm conv /////////////////////////////////////////////////////////////////////////////// template void flip_spatial_dims_inplace( T* x, size_t in_channels, size_t out_channels, size_t spatial_size) { for (size_t i = 0; i < out_channels; i++) { T* top = x + i * spatial_size * in_channels; T* bottom = x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels; for (size_t j = 0; j < spatial_size / 2; j++) { for (size_t k = 0; k < in_channels; k++) { std::swap(top[k], bottom[k]); } top += in_channels; bottom -= in_channels; } } } void explicit_gemm_conv_1D_cpu( const array& in, const array& wt, array out, const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int iH = in.shape(1); // Input spatial dim const int C = in.shape(2); // Input channels const int oH = out.shape(1); // Output spatial dim const int O = wt.shape(0); // Out channels const int wH = wt.shape(1); // Weight spatial dim const int groups = C / wt.shape(2); const int C_per_group = wt.shape(2); const int O_per_group = O / groups; auto conv_dtype = float32; auto& encoder = cpu::get_command_encoder(stream); // Pad input Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros std::vector temps; temps.push_back(array(0, conv_dtype)); copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded size_t data_offset = padding_lo[0] * in_padded.strides()[1]; array 
in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, in_padded.strides(), in_padded.flags(), in_padded_slice.size(), data_offset); // Copy input values into the slice copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); temps.push_back(in_padded_slice); // Make strided view Shape strided_shape = {N, oH, wH, C}; Strides strided_strides = { in_padded.strides()[0], in_padded.strides()[1] * wt_strides[0], in_padded.strides()[1], in_padded.strides()[2]}; auto flags = in_padded.flags(); if (groups > 1) { // Transpose the last two dimensions for grouped convolutions std::swap(strided_shape[2], strided_shape[3]); std::swap(strided_strides[2], strided_strides[3]); } array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); in_strided_view.copy_shared_buffer( in_padded, strided_strides, flags, in_strided_view.size(), 0); // Materialize strided view Shape strided_reshape = {N * oH, wH * C}; array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); copy_cpu(in_strided_view, in_strided, CopyType::General, stream); temps.push_back(in_strided); // Check wt dtype and prepare auto gemm_wt = wt; auto gemm_out = out; if (groups > 1) { // Transpose the last two dimensions for grouped convolutions array wt_transpose( {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {}); wt_transpose.copy_shared_buffer( wt, {wt.strides(0), wt.strides(2), wt.strides(1)}, wt.flags(), wt.size(), 0); gemm_wt = array(wt_transpose.shape(), float32, nullptr, {}); copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream); temps.push_back(gemm_wt); } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) { auto ctype = wt.flags().row_contiguous ? 
CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); copy_cpu(wt, gemm_wt, ctype, stream); temps.push_back(gemm_wt); } if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out.set_data(allocator::malloc(gemm_out.nbytes())); temps.push_back(gemm_out); } encoder.set_input_array(in_strided); encoder.set_input_array(gemm_wt); encoder.set_output_array(gemm_out); encoder.dispatch([in_strided_ptr = in_strided.data(), gemm_wt_ptr = gemm_wt.data(), gemm_out_ptr = gemm_out.data(), groups, strided_reshape = strided_reshape[0], O, C, wH, O_per_group, C_per_group]() { for (int g = 0; g < groups; ++g) { // Perform gemm cblas_sgemm( CblasRowMajor, CblasNoTrans, // no trans A CblasTrans, // transB strided_reshape, // M O_per_group, // N C_per_group * wH, // K 1.0f, // alpha in_strided_ptr + g * C_per_group * wH, // A wH * C, // lda gemm_wt_ptr + g * O_per_group * C_per_group * wH, // B wH * C_per_group, // ldb 0.0f, // beta gemm_out_ptr + g * O_per_group, // C O // ldc ); } }); // Copy results if needed if (out.dtype() != float32) { copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream); } encoder.add_temporaries(std::move(temps)); } void explicit_gemm_conv_ND_cpu( const array& in, const array& wt, array out, const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const bool flip, Stream stream) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const auto iDim = Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim const auto oDim = Shape( out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim const int O = wt.shape(0); // Out channels const int C = wt.shape(-1); // In channels const auto wDim = Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim auto conv_dtype = float32; auto& encoder = cpu::get_command_encoder(stream); // Pad input Shape 
padded_shape(in.shape().size()); padded_shape.front() = N; for (size_t i = 0; i < iDim.size(); i++) { padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i]; } padded_shape.back() = C; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros std::vector temps = {array(0, conv_dtype)}; copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded size_t data_offset = 0; for (size_t i = 0; i < padding_lo.size(); i++) { data_offset += padding_lo[i] * in_padded.strides()[i + 1]; } array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, in_padded.strides(), in_padded.flags(), in_padded_slice.size(), data_offset); // Copy input values into the slice copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); temps.push_back(in_padded_slice); // Make strided view Shape strided_shape(oDim.size() + wDim.size() + 2); strided_shape.front() = N; for (size_t i = 0; i < oDim.size(); i++) { strided_shape[i + 1] = oDim[i]; } for (size_t i = 0; i < wDim.size(); i++) { strided_shape[i + 1 + oDim.size()] = wDim[i]; } strided_shape.back() = C; Strides strided_strides(in.shape().size() * 2 - 2); strided_strides[0] = in_padded.strides()[0]; for (size_t i = 0; i < wt_strides.size(); i++) { strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i]; } for (size_t i = 1; i < in_padded.strides().size(); i++) { strided_strides[i + wt_strides.size()] = in_padded.strides()[i]; } auto flags = in_padded.flags(); array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); in_strided_view.copy_shared_buffer( in_padded, strided_strides, flags, in_strided_view.size(), 0); // Materialize strided view Shape strided_reshape = {N, C}; for (const auto& o : oDim) { strided_reshape[0] *= o; } for (const auto& w : wDim) { strided_reshape[1] *= w; } array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); copy_cpu(in_strided_view, in_strided, 
CopyType::General, stream); temps.push_back(in_strided); // Check wt dtype and prepare auto gemm_wt = wt; auto gemm_out = out; if (wt.dtype() != float32 || !wt.flags().row_contiguous) { auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); copy_cpu(wt, gemm_wt, ctype, stream); temps.push_back(gemm_wt); } if (flip) { auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {}); copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream); temps.push_back(gemm_wt_); // Calculate the total size of the spatial dimensions int spatial_size = 1; for (int d = 1; d < gemm_wt.ndim() - 1; ++d) { spatial_size *= gemm_wt.shape(d); } encoder.set_output_array(gemm_wt_); encoder.dispatch([gemm_wt_ptr = gemm_wt_.data(), out_channels = gemm_wt.shape(0), in_channels = gemm_wt.shape(-1), spatial_size]() { flip_spatial_dims_inplace( gemm_wt_ptr, in_channels, out_channels, spatial_size); }); gemm_wt = gemm_wt_; } if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out.set_data(allocator::malloc(gemm_out.nbytes())); temps.push_back(gemm_out); } encoder.set_input_array(in_strided); encoder.set_input_array(gemm_wt); encoder.set_output_array(gemm_out); encoder.dispatch([in_strided_ptr = in_strided.data(), gemm_wt_ptr = gemm_wt.data(), gemm_out_ptr = gemm_out.data(), strided_reshape = std::move(strided_reshape), O]() { // Perform gemm cblas_sgemm( CblasRowMajor, CblasNoTrans, // no trans A CblasTrans, // transB strided_reshape[0], // M O, // N strided_reshape[1], // K 1.0f, // alpha in_strided_ptr, strided_reshape[1], // lda gemm_wt_ptr, strided_reshape[1], // ldb 0.0f, // beta gemm_out_ptr, O // ldc ); }); // Copy results if needed if (out.dtype() != float32) { copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream); } encoder.add_temporaries(std::move(temps)); } /////////////////////////////////////////////////////////////////////////////// // Conv routing 
/////////////////////////////////////////////////////////////////////////////// void conv_1D_cpu( const array& in, const array& wt, array out, const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip, Stream stream) { const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { return explicit_gemm_conv_1D_cpu( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream); } if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, flip, stream); } return dispatch_slow_conv_1D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } void conv_2D_cpu( const array& in, const array& wt, array out, const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip, Stream stream) { const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, flip, stream); } return dispatch_slow_conv_2D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } void conv_3D_cpu( const array& in, const array& wt, array out, const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip, Stream stream) { const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && groups == 1) { return 
explicit_gemm_conv_ND_cpu( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, flip, stream); } return dispatch_slow_conv_3D( in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, in_dilation, flip, stream); } } // namespace void Convolution::eval_cpu(const std::vector& inputs, array& out) { out.set_data(allocator::malloc(out.nbytes())); auto& in = inputs[0]; auto& wt = inputs[1]; // 3D convolution if (in.ndim() == (3 + 2)) { return conv_3D_cpu( in, wt, out, padding_lo_, padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, flip_, stream()); } // 2D convolution else if (in.ndim() == (2 + 2)) { return conv_2D_cpu( in, wt, out, padding_lo_, padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, flip_, stream()); } // 1D convolution else if (in.ndim() == (1 + 2)) { return conv_1D_cpu( in, wt, out, padding_lo_, padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, flip_, stream()); } // Throw error else { std::ostringstream msg; msg << "[Convolution::eval] Convolution currently only supports" << " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2 << " spatial dimensions"; throw std::invalid_argument(msg.str()); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/copy.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" namespace mlx::core { namespace { template void copy_single(const array& src, array& dst) { auto src_ptr = src.data(); auto dst_ptr = dst.data(); auto size = dst.size(); auto val = static_cast(src_ptr[0]); std::fill_n(dst_ptr, size, val); } template void copy_vector(const array& src, array& dst) { auto src_ptr = src.data(); auto dst_ptr = dst.data(); auto size = src.data_size(); std::copy(src_ptr, src_ptr + size, dst_ptr); } template inline void copy_dims( const SrcT* src, DstT* dst, const Shape& shape, const Strides& i_strides, const Strides& o_strides, int axis) { auto stride_src = i_strides[axis]; auto stride_dst = o_strides[axis]; auto N = shape[axis]; for (int i = 0; i < N; i++) { if constexpr (D > 1) { copy_dims( src, dst, shape, i_strides, o_strides, axis + 1); } else { *dst = static_cast(*src); } src += stride_src; dst += stride_dst; } } template void copy_general_general( const array& src, array& dst, const Shape& data_shape, const Strides& i_strides, const Strides& o_strides, int64_t i_offset, int64_t o_offset, const std::optional& dynamic_i_offset, const std::optional& dynamic_o_offset) { auto src_ptr = src.data() + i_offset; auto dst_ptr = dst.data() + o_offset; auto i_offset_ptr = dynamic_i_offset ? dynamic_i_offset->data() : nullptr; auto o_offset_ptr = dynamic_o_offset ? 
dynamic_o_offset->data() : nullptr; auto size = src.size(); if (data_shape.empty()) { auto val = static_cast(*src_ptr); *dst_ptr = val; return; } auto [shape, strides] = collapse_contiguous_dims(data_shape, {i_strides, o_strides}); int ndim = shape.size(); if (ndim < 3) { if (i_offset_ptr) { src_ptr += i_offset_ptr[0]; } if (o_offset_ptr) { dst_ptr += o_offset_ptr[0]; } if (ndim == 1) { copy_dims( src_ptr, dst_ptr, shape, strides[0], strides[1], 0); } else if (ndim == 2) { copy_dims( src_ptr, dst_ptr, shape, strides[0], strides[1], 0); } else if (ndim == 3) { copy_dims( src_ptr, dst_ptr, shape, strides[0], strides[1], 0); } return; } if (i_offset_ptr) { src_ptr += i_offset_ptr[0]; } if (o_offset_ptr) { dst_ptr += o_offset_ptr[0]; } ContiguousIterator in(shape, strides[0], ndim - 3); ContiguousIterator out(shape, strides[1], ndim - 3); auto stride = std::accumulate( shape.end() - 3, shape.end(), 1, std::multiplies()); for (int64_t elem = 0; elem < size; elem += stride) { copy_dims( src_ptr + in.loc, dst_ptr + out.loc, shape, strides[0], strides[1], ndim - 3); in.step(); out.step(); } } template inline void copy_general_general(const array& src, array& dst) { copy_general_general( src, dst, src.shape(), src.strides(), dst.strides(), 0, 0, std::nullopt, std::nullopt); } template void copy_general( const array& src, array& dst, const Shape& data_shape, const Strides& i_strides, const Strides&, int64_t i_offset, int64_t o_offset, const std::optional& dynamic_i_offset, const std::optional& dynamic_o_offset) { copy_general_general( src, dst, data_shape, i_strides, make_contiguous_strides(data_shape), i_offset, o_offset, dynamic_i_offset, dynamic_o_offset); } template inline void copy_general(const array& src, array& dst) { copy_general_general( src, dst, src.shape(), src.strides(), make_contiguous_strides(src.shape()), 0, 0, std::nullopt, std::nullopt); } template void copy(const array& src, array& dst, CopyType ctype, Args&&... 
args) { switch (ctype) { case CopyType::Scalar: copy_single(src, dst); return; case CopyType::Vector: copy_vector(src, dst); return; case CopyType::General: copy_general(src, dst, std::forward(args)...); return; case CopyType::GeneralGeneral: copy_general_general(src, dst, std::forward(args)...); return; } } template void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { switch (dst.dtype()) { case bool_: copy(src, dst, ctype, std::forward(args)...); break; case uint8: copy(src, dst, ctype, std::forward(args)...); break; case uint16: copy(src, dst, ctype, std::forward(args)...); break; case uint32: copy(src, dst, ctype, std::forward(args)...); break; case uint64: copy(src, dst, ctype, std::forward(args)...); break; case int8: copy(src, dst, ctype, std::forward(args)...); break; case int16: copy(src, dst, ctype, std::forward(args)...); break; case int32: copy(src, dst, ctype, std::forward(args)...); break; case int64: copy(src, dst, ctype, std::forward(args)...); break; case float16: copy(src, dst, ctype, std::forward(args)...); break; case float32: copy(src, dst, ctype, std::forward(args)...); break; case float64: copy(src, dst, ctype, std::forward(args)...); break; case bfloat16: copy(src, dst, ctype, std::forward(args)...); break; case complex64: copy(src, dst, ctype, std::forward(args)...); break; } } template inline void copy_inplace_dispatch( const array& src, array& dst, CopyType ctype, Args&&... 
args) { switch (src.dtype()) { case bool_: copy(src, dst, ctype, std::forward(args)...); break; case uint8: copy(src, dst, ctype, std::forward(args)...); break; case uint16: copy(src, dst, ctype, std::forward(args)...); break; case uint32: copy(src, dst, ctype, std::forward(args)...); break; case uint64: copy(src, dst, ctype, std::forward(args)...); break; case int8: copy(src, dst, ctype, std::forward(args)...); break; case int16: copy(src, dst, ctype, std::forward(args)...); break; case int32: copy(src, dst, ctype, std::forward(args)...); break; case int64: copy(src, dst, ctype, std::forward(args)...); break; case float16: copy(src, dst, ctype, std::forward(args)...); break; case float32: copy(src, dst, ctype, std::forward(args)...); break; case float64: copy(src, dst, ctype, std::forward(args)...); break; case bfloat16: copy(src, dst, ctype, std::forward(args)...); break; case complex64: copy(src, dst, ctype, std::forward(args)...); break; } } } // namespace void copy_cpu_inplace( const array& src, array& dst, CopyType ctype, Stream stream) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(src); encoder.set_output_array(dst); encoder.dispatch( [src = array::unsafe_weak_copy(src), dst = array::unsafe_weak_copy(dst), ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); }); } void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) { bool donated = set_copy_output_data(src, dst, ctype); if (donated && src.dtype() == dst.dtype()) { // If the output has the same type as the input then there is nothing to // copy, just use the buffer. 
return; } if (ctype == CopyType::GeneralGeneral) { ctype = CopyType::General; } copy_cpu_inplace(src, dst, ctype, stream); } void copy_cpu_inplace( const array& src, array& dst, const Shape& data_shape, const Strides& i_strides, const Strides& o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, Stream stream, const std::optional& dynamic_i_offset, /* = std::nullopt */ const std::optional& dynamic_o_offset /* = std::nullopt */) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(src); encoder.set_output_array(dst); auto weak_copy_if_set = [](auto x) -> std::optional { if (x) { return array::unsafe_weak_copy(*x); } else { return std::nullopt; } }; encoder.dispatch( [src = array::unsafe_weak_copy(src), dst = array::unsafe_weak_copy(dst), data_shape, i_strides, o_strides, i_offset, o_offset, ctype, dynamic_i_offset = weak_copy_if_set(dynamic_i_offset), dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable { switch (ctype) { case CopyType::General: case CopyType::GeneralGeneral: copy_inplace_dispatch( src, dst, ctype, data_shape, i_strides, o_strides, i_offset, o_offset, dynamic_i_offset, dynamic_o_offset); break; case CopyType::Scalar: case CopyType::Vector: copy_inplace_dispatch(src, dst, ctype); } }); } array contiguous_copy_cpu(const array& arr, Stream stream) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_cpu(arr, arr_copy, CopyType::General, stream); return arr_copy; } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/copy.h ================================================ // Copyright © 2023-2024 Apple Inc. 
#pragma once #include #include "mlx/array.h" #include "mlx/backend/common/copy.h" #include "mlx/backend/common/utils.h" namespace mlx::core { void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream); void copy_cpu_inplace( const array& src, array& dst, CopyType ctype, Stream stream); void copy_cpu_inplace( const array& src, array& dst, const Shape& data_shape, const Strides& i_strides, const Strides& o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, Stream stream, const std::optional& dynamic_i_offset = std::nullopt, const std::optional& dynamic_o_offset = std::nullopt); // Return a contiguous array with same shape that copies the data of |arr|. array contiguous_copy_cpu(const array& arr, Stream stream); } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/device_info.cpp ================================================ // Copyright © 2026 Apple Inc. #include "mlx/backend/cpu/device_info.h" #ifdef __APPLE__ #include #include #elif defined(_WIN32) #include #else #include #include #endif namespace mlx::core::cpu { namespace { // Get CPU architecture string at runtime std::string get_cpu_architecture() { #ifdef _WIN32 // Use GetNativeSystemInfo to get the actual hardware architecture, // even when running under WoW64 emulation SYSTEM_INFO sysInfo; GetNativeSystemInfo(&sysInfo); switch (sysInfo.wProcessorArchitecture) { case PROCESSOR_ARCHITECTURE_AMD64: return "x86_64"; case PROCESSOR_ARCHITECTURE_ARM64: return "arm64"; case PROCESSOR_ARCHITECTURE_INTEL: return "x86"; case PROCESSOR_ARCHITECTURE_ARM: return "arm"; default: return "unknown"; } #else // Use uname() for runtime detection on Unix-like systems. 
// This returns the actual hardware architecture (e.g., "arm64" on Apple // Silicon even when running x86_64 binaries via Rosetta 2) struct utsname info; if (uname(&info) == 0) { return std::string(info.machine); } return "unknown"; #endif } // Get CPU device name (brand string) std::string get_cpu_name() { #ifdef __APPLE__ char model[256]; size_t len = sizeof(model); if (sysctlbyname("machdep.cpu.brand_string", &model, &len, NULL, 0) == 0) { return std::string(model); } #elif defined(_WIN32) // Read CPU brand string from registry HKEY hKey; if (RegOpenKeyExA( HKEY_LOCAL_MACHINE, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", 0, KEY_READ, &hKey) == ERROR_SUCCESS) { char brand[256]; DWORD size = sizeof(brand); if (RegQueryValueExA( hKey, "ProcessorNameString", NULL, NULL, (LPBYTE)brand, &size) == ERROR_SUCCESS) { RegCloseKey(hKey); return std::string(brand); } RegCloseKey(hKey); } #else // Try reading from /proc/cpuinfo on Linux std::ifstream cpuinfo("/proc/cpuinfo"); if (cpuinfo.is_open()) { std::string line; while (std::getline(cpuinfo, line)) { if (line.starts_with("model name")) { if (auto n = line.find(": "); n != std::string::npos) { return line.substr(n + 2); } } } } #endif return get_cpu_architecture(); } } // anonymous namespace bool is_available() { return true; } int device_count() { return 1; } const std::unordered_map>& device_info(int /* device_index */) { static auto info = std::unordered_map>{ {"device_name", get_cpu_name()}, {"architecture", get_cpu_architecture()}}; return info; } } // namespace mlx::core::cpu ================================================ FILE: mlx/backend/cpu/device_info.h ================================================ // Copyright © 2026 Apple Inc. #pragma once #include #include #include namespace mlx::core::cpu { bool is_available(); /** * Get the number of available CPU devices. * * For CPU, always returns 1. */ int device_count(); /** * Get CPU device information. * * Returns a map with basic CPU device properties. 
*/ const std::unordered_map>& device_info(int device_index = 0); } // namespace mlx::core::cpu ================================================ FILE: mlx/backend/cpu/distributed.cpp ================================================ // Copyright © 2024 Apple Inc. #include #include "mlx/allocator.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/distributed/primitives.h" namespace mlx::core::distributed { std::pair ensure_row_contiguous(const array& arr, Stream stream) { if (arr.flags().row_contiguous) { return {arr, false}; } else { return {contiguous_copy_cpu(arr, stream), true}; } }; void AllReduce::eval_cpu( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() == 1); assert(outputs.size() == 1); auto donate_or_copy = [s = stream()](const array& in, array& out) { if (in.flags().row_contiguous) { if (in.is_donatable()) { out.copy_shared_buffer(in); } else { out.set_data(allocator::malloc(out.nbytes())); } return in; } else { array arr_copy = contiguous_copy_cpu(in, s); out.copy_shared_buffer(arr_copy); return arr_copy; } }; auto in = donate_or_copy(inputs[0], outputs[0]); switch (reduce_type_) { case Sum: distributed::detail::all_sum(group(), in, outputs[0], stream()); break; case Max: distributed::detail::all_max(group(), in, outputs[0], stream()); break; case Min: distributed::detail::all_min(group(), in, outputs[0], stream()); break; default: throw std::runtime_error( "Only all reduce sum, min and max are supported for now"); } } void AllGather::eval_cpu( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() == 1); assert(outputs.size() == 1); auto [in, copied] = ensure_row_contiguous(inputs[0], stream()); outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); distributed::detail::all_gather(group(), in, outputs[0], stream()); if (copied) { auto& enc = cpu::get_command_encoder(stream()); enc.add_temporary(in); } } void Send::eval_cpu( const std::vector& inputs, std::vector& outputs) 
{ assert(inputs.size() == 1); assert(outputs.size() == 1); auto [in, copied] = ensure_row_contiguous(inputs[0], stream()); distributed::detail::send(group(), in, dst_, stream()); outputs[0].copy_shared_buffer(inputs[0]); if (copied) { auto& enc = cpu::get_command_encoder(stream()); enc.add_temporary(in); } } void Recv::eval_cpu( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() == 0); assert(outputs.size() == 1); outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); distributed::detail::recv(group(), outputs[0], src_, stream()); } void ReduceScatter::eval_cpu( const std::vector& inputs, std::vector& outputs) { throw std::runtime_error("[ReduceScatter] Not implemented yet."); } } // namespace mlx::core::distributed ================================================ FILE: mlx/backend/cpu/eig.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/allocator.h" #include "mlx/array.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/linalg.h" #include "mlx/primitives.h" namespace mlx::core { namespace { template complex64_t to_complex(T r, T i) { return {static_cast(r), static_cast(i)}; } template struct EigWork {}; template struct EigWork< T, typename std::enable_if::value>::type> { using O = complex64_t; char jobl; char jobr; int N; int lwork; int info; std::vector buffers; EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors) : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) { T work; int n_vecs_l = compute_eigenvectors ? 
N_ : 1; int n_vecs_r = 1; geev( &jobl, &jobr, &N, nullptr, &N, nullptr, nullptr, nullptr, &n_vecs_l, nullptr, &n_vecs_r, &work, &lwork, &info); lwork = static_cast(work); buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2)); if (compute_eigenvectors) { buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2)); } buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); } void run(T* a, O* values, O* vectors) { auto eig_tmp = static_cast(buffers[0].buffer.raw_ptr()); T* vec_tmp = nullptr; if (vectors) { vec_tmp = static_cast(buffers[1].buffer.raw_ptr()); } auto work = static_cast(buffers.back().buffer.raw_ptr()); int n_vecs_l = vectors ? N : 1; int n_vecs_r = 1; geev( &jobl, &jobr, &N, a, &N, eig_tmp, eig_tmp + N, vectors ? vec_tmp : nullptr, &n_vecs_l, nullptr, &n_vecs_r, work, &lwork, &info); for (int i = 0; i < N; ++i) { values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]); } if (vectors) { for (int i = 0; i < N; ++i) { if (values[i].imag() != 0) { for (int j = 0; j < N; ++j) { vectors[i * N + j] = to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]); vectors[(i + 1) * N + j] = to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]); } i += 1; } else { for (int j = 0; j < N; ++j) { vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0)); } } } } } }; template <> struct EigWork> { using T = std::complex; using R = float; using O = T; char jobl; char jobr; int N; int lwork; int lrwork; int info; std::vector buffers; EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors) : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) { T work; R rwork; int n_vecs_l = compute_eigenvectors ? 
N_ : 1; int n_vecs_r = 1; geev( &jobl, &jobr, &N, nullptr, &N, nullptr, nullptr, &n_vecs_l, nullptr, &n_vecs_r, &work, &lwork, &rwork, &info); lwork = static_cast(work.real()); buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork)); } void run(T* a, T* values, T* vectors) { int n_vecs_l = vectors ? N : 1; int n_vecs_r = 1; geev( &jobl, &jobr, &N, a, &N, values, vectors, &n_vecs_l, nullptr, &n_vecs_r, static_cast(buffers[0].buffer.raw_ptr()), &lwork, static_cast(buffers[1].buffer.raw_ptr()), &info); } }; template void eig_impl( array& a, array& vectors, array& values, bool compute_eigenvectors, Stream stream) { auto a_ptr = a.data(); auto val_ptr = values.data(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(a); encoder.set_output_array(values); complex64_t* vec_ptr = nullptr; if (compute_eigenvectors) { encoder.set_output_array(vectors); vec_ptr = vectors.data(); } encoder.dispatch([a_ptr, val_ptr, vec_ptr, compute_eigenvectors, N = vectors.shape(-1), size = vectors.size()]() mutable { char jobr = 'N'; char jobl = compute_eigenvectors ? 'V' : 'N'; EigWork work(jobl, jobr, N, compute_eigenvectors); for (size_t i = 0; i < size / (N * N); ++i) { work.run(a_ptr, val_ptr, vec_ptr); a_ptr += N * N; val_ptr += N; if (vec_ptr) { vec_ptr += N * N; } if (work.info != 0) { std::stringstream msg; msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code " << work.info; throw std::runtime_error(msg.str()); } } }); encoder.add_temporary(a); } } // namespace void Eig::eval_cpu( const std::vector& inputs, std::vector& outputs) { const auto& a = inputs[0]; auto& values = outputs[0]; auto vectors = compute_eigenvectors_ ? outputs[1] : array(a.shape(), complex64, nullptr, {}); auto a_copy = array(a.shape(), a.dtype(), nullptr, {}); copy_cpu( a, a_copy, a.flags().row_contiguous ? 
CopyType::Vector : CopyType::General, stream()); values.set_data(allocator::malloc(values.nbytes())); if (compute_eigenvectors_) { // Set the strides and flags so the eigenvectors // are in the columns of the output auto flags = vectors.flags(); auto strides = vectors.strides(); auto ndim = a.ndim(); std::swap(strides[ndim - 1], strides[ndim - 2]); if (a.size() > 1) { flags.row_contiguous = false; if (ndim > 2) { flags.col_contiguous = false; } else { flags.col_contiguous = true; } } vectors.set_data( allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags); } switch (a.dtype()) { case float32: eig_impl(a_copy, vectors, values, compute_eigenvectors_, stream()); break; case float64: eig_impl( a_copy, vectors, values, compute_eigenvectors_, stream()); break; case complex64: eig_impl>( a_copy, vectors, values, compute_eigenvectors_, stream()); break; default: throw std::runtime_error( "[Eig::eval_cpu] only supports float32, float64, or complex64."); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/eigh.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename T, class Enable = void>
struct EighWork {};

// Workspace for real syevd: queries lwork/liwork at construction and owns the
// work and iwork buffers.
template <typename T>
struct EighWork<
    T,
    typename std::enable_if<std::is_floating_point<T>::value>::type> {
  using R = T;

  char jobz;
  char uplo;
  int N;
  int lwork;
  int liwork;
  int info;
  std::vector<array::Data> buffers;

  EighWork(char jobz_, char uplo_, int N_)
      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
    T work;
    int iwork;
    // Workspace query (lwork == liwork == -1).
    syevd<T>(
        &jobz,
        &uplo,
        &N,
        nullptr,
        &N,
        nullptr,
        &work,
        &lwork,
        &iwork,
        &liwork,
        &info);
    lwork = static_cast<int>(work);
    liwork = iwork;
    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
  }

  void run(T* vectors, T* values) {
    syevd<T>(
        &jobz,
        &uplo,
        &N,
        vectors,
        &N,
        values,
        static_cast<T*>(buffers[0].buffer.raw_ptr()),
        &lwork,
        static_cast<int*>(buffers[1].buffer.raw_ptr()),
        &liwork,
        &info);
  }
};

// Workspace for complex heevd: queries lwork/lrwork/liwork at construction and
// owns the work, rwork, and iwork buffers.
template <>
struct EighWork<std::complex<float>> {
  using T = std::complex<float>;
  using R = float;

  char jobz;
  char uplo;
  int N;
  int lwork;
  int lrwork;
  int liwork;
  int info;
  std::vector<array::Data> buffers;

  EighWork(char jobz_, char uplo_, int N_)
      : jobz(jobz_),
        uplo(uplo_),
        N(N_),
        lwork(-1),
        lrwork(-1),
        liwork(-1) {
    T work;
    R rwork;
    int iwork;
    // Workspace query (all sizes == -1).
    heevd<T>(
        &jobz,
        &uplo,
        &N,
        nullptr,
        &N,
        nullptr,
        &work,
        &lwork,
        &rwork,
        &lrwork,
        &iwork,
        &liwork,
        &info);
    lwork = static_cast<int>(work.real());
    lrwork = static_cast<int>(rwork);
    liwork = iwork;
    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
  }

  void run(T* vectors, R* values) {
    heevd<T>(
        &jobz,
        &uplo,
        &N,
        vectors,
        &N,
        values,
        static_cast<T*>(buffers[0].buffer.raw_ptr()),
        &lwork,
        static_cast<R*>(buffers[1].buffer.raw_ptr()),
        &lrwork,
        static_cast<int*>(buffers[2].buffer.raw_ptr()),
        &liwork,
        &info);
    if (jobz == 'V') {
      // We have pre-transposed the vectors but we also must conjugate them
      // when they are complex.
      //
      // We could vectorize this but it is so fast in comparison to heevd that
      // it doesn't really matter.
      for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
          *vectors = std::conj(*vectors);
          vectors++;
        }
      }
    }
  }
};

// Enqueues a batched symmetric/Hermitian eigendecomposition. `vectors` holds
// the input matrices on entry; eigenvalues go to `values`. Throws
// std::runtime_error (from the task) if LAPACK reports a non-zero info.
template <typename T>
void eigh_impl(
    array& vectors,
    array& values,
    const std::string& uplo,
    bool compute_eigenvectors,
    Stream stream) {
  using R = typename EighWork<T>::R;

  auto vec_ptr = vectors.data<T>();
  auto eig_ptr = values.data<R>();
  char jobz = compute_eigenvectors ? 'V' : 'N';

  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_output_array(vectors);
  encoder.set_output_array(values);
  encoder.dispatch([vec_ptr,
                    eig_ptr,
                    jobz,
                    uplo = uplo[0],
                    N = vectors.shape(-1),
                    size = vectors.size()]() mutable {
    // 0 x 0 matrices: nothing to compute, and the batch count below would
    // otherwise divide by zero.
    if (N == 0) {
      return;
    }
    // Work query
    EighWork<T> work(jobz, uplo, N);

    // Work loop
    for (size_t i = 0; i < size / (N * N); ++i) {
      work.run(vec_ptr, eig_ptr);
      vec_ptr += N * N;
      eig_ptr += N;
      if (work.info != 0) {
        std::stringstream msg;
        msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
            << work.info;
        throw std::runtime_error(msg.str());
      }
    }
  });
  if (!compute_eigenvectors) {
    encoder.add_temporary(vectors);
  }
}

} // namespace

void Eigh::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  const auto& a = inputs[0];
  auto& values = outputs[0];

  auto vectors = compute_eigenvectors_
      ? outputs[1]
      : array(a.shape(), a.dtype(), nullptr, {});

  values.set_data(allocator::malloc(values.nbytes()));

  // The decomposition runs in place on `vectors`, so seed it with the input.
  copy_cpu(
      a,
      vectors,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
      stream());

  if (compute_eigenvectors_) {
    // Set the strides and flags so the eigenvectors
    // are in the columns of the output
    auto flags = vectors.flags();
    auto strides = vectors.strides();
    auto ndim = a.ndim();
    std::swap(strides[ndim - 1], strides[ndim - 2]);

    if (a.size() > 1) {
      flags.row_contiguous = false;
      if (ndim > 2) {
        flags.col_contiguous = false;
      } else {
        flags.col_contiguous = true;
      }
    }
    vectors.copy_shared_buffer(vectors, strides, flags, vectors.data_size());
  }
  switch (a.dtype()) {
    case float32:
      eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_, stream());
      break;
    case float64:
      eigh_impl<double>(
          vectors, values, uplo_, compute_eigenvectors_, stream());
      break;
    case complex64:
      eigh_impl<std::complex<float>>(
          vectors, values, uplo_, compute_eigenvectors_, stream());
      break;
    default:
      throw std::runtime_error(
          "[Eigh::eval_cpu] only supports float32, float64, or complex64.");
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cpu/encoder.cpp
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cpu/encoder.h"

namespace mlx::core::cpu {

// Returns the encoder for `stream`, creating it on first use. Encoders are
// keyed by stream index and live in a function-local static map.
CommandEncoder& get_command_encoder(Stream stream) {
  static std::unordered_map<int, CommandEncoder> encoder_map;
  auto it = encoder_map.find(stream.index);
  if (it == encoder_map.end()) {
    it = encoder_map.emplace(stream.index, stream).first;
  }
  return it->second;
}

} // namespace mlx::core::cpu

================================================
FILE: mlx/backend/cpu/encoder.h
================================================
// Copyright © 2025 Apple Inc.
#pragma once #include #include "mlx/array.h" #include "mlx/scheduler.h" namespace mlx::core::cpu { // Number of dispatches per scheduler task constexpr int DISPATCHES_PER_TASK = 10; struct MLX_API CommandEncoder { CommandEncoder(Stream stream) : stream_(stream) {} CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; CommandEncoder(CommandEncoder&&) = delete; CommandEncoder& operator=(CommandEncoder&&) = delete; void set_input_array(const array& a) {} void set_output_array(array& a) {} // Hold onto a temporary until any already scheduled tasks which use it as // an input are complete. void add_temporary(array arr) { temporaries_.push_back(std::move(arr)); } void add_temporaries(std::vector arrays) { temporaries_.insert( temporaries_.end(), std::make_move_iterator(arrays.begin()), std::make_move_iterator(arrays.end())); } std::vector& temporaries() { return temporaries_; } template void dispatch(F&& f, Args&&... args) { num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK; auto task = std::bind(std::forward(f), std::forward(args)...); if (num_ops_ == 0) { scheduler::notify_new_task(stream_); auto task_wrap = [s = stream_, task = std::move(task)]() mutable { task(); scheduler::notify_task_completion(s); }; scheduler::enqueue(stream_, std::move(task_wrap)); } else { scheduler::enqueue(stream_, std::move(task)); } } private: Stream stream_; std::vector temporaries_; int num_ops_{0}; }; MLX_API CommandEncoder& get_command_encoder(Stream stream); } // namespace mlx::core::cpu ================================================ FILE: mlx/backend/cpu/eval.cpp ================================================ // Copyright © 2025 Apple Inc. 
#include <unordered_set>

#include "mlx/backend/cpu/eval.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"

namespace mlx::core::cpu {

// Evaluates `arr` with its primitive's CPU kernel. Afterwards it enqueues a
// no-op task that captures the input/sibling buffers and the encoder's
// temporaries, so they are released only when that task runs on the stream.
void eval(array& arr) {
  auto s = arr.primitive().stream();

  auto outputs = arr.outputs();
  {
    // If the array is a tracer hold a reference
    // to its inputs so they don't get donated
    std::vector<array> inputs;
    if (arr.is_tracer()) {
      inputs = arr.inputs();
    }
    arr.primitive().eval_cpu(arr.inputs(), outputs);
  }

  std::unordered_set<std::shared_ptr<array::Data>> buffers;
  for (auto& in : arr.inputs()) {
    buffers.insert(in.data_shared_ptr());
  }
  // Named `sibling` (not `s`) so it does not shadow the stream used below.
  for (auto& sibling : arr.siblings()) {
    buffers.insert(sibling.data_shared_ptr());
  }
  // Remove the output if it was donated to by an input
  if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
    buffers.erase(it);
  }
  auto& encoder = cpu::get_command_encoder(s);
  // Moving the vector out leaves the encoder's temporaries empty.
  encoder.dispatch([buffers = std::move(buffers),
                    temps = std::move(encoder.temporaries())]() {});
}

} // namespace mlx::core::cpu

================================================
FILE: mlx/backend/cpu/eval.h
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/stream.h"

namespace mlx::core::cpu {

void eval(array& arr);

} // namespace mlx::core::cpu

================================================
FILE: mlx/backend/cpu/fft.cpp
================================================
// Copyright © 2023 Apple Inc.
#include #include "mlx/3rdparty/pocketfft.h" #include "mlx/allocator.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" namespace mlx::core { void FFT::eval_cpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; std::vector strides_in( in.strides().begin(), in.strides().end()); for (auto& s : strides_in) { s *= in.itemsize(); } std::vector strides_out( out.strides().begin(), out.strides().end()); for (auto& s : strides_out) { s *= out.itemsize(); } out.set_data(allocator::malloc(out.nbytes())); std::vector shape; if (out.dtype() == float32) { shape.insert(shape.end(), out.shape().begin(), out.shape().end()); } else { shape.insert(shape.end(), in.shape().begin(), in.shape().end()); } float scale = 1.0f; if (inverse_) { size_t nelem = std::accumulate( axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) { return x * shape[y]; }); scale /= nelem; } auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); encoder.set_output_array(out); if (in.dtype() == complex64 && out.dtype() == complex64) { auto in_ptr = reinterpret_cast*>(in.data()); auto out_ptr = reinterpret_cast*>(out.data()); encoder.dispatch([shape = std::move(shape), strides_in = std::move(strides_in), strides_out = std::move(strides_out), axes = axes_, inverse = inverse_, in_ptr, out_ptr, scale]() { pocketfft::c2c( shape, strides_in, strides_out, axes, !inverse, in_ptr, out_ptr, scale); }); } else if (in.dtype() == float32 && out.dtype() == complex64) { auto in_ptr = in.data(); auto out_ptr = reinterpret_cast*>(out.data()); encoder.dispatch([shape = std::move(shape), strides_in = std::move(strides_in), strides_out = std::move(strides_out), axes = axes_, inverse = inverse_, in_ptr, out_ptr, scale]() { pocketfft::r2c( shape, strides_in, strides_out, axes, !inverse, in_ptr, out_ptr, scale); }); } else if (in.dtype() == complex64 && out.dtype() == float32) { auto in_ptr = reinterpret_cast*>(in.data()); auto out_ptr = out.data(); encoder.dispatch([shape = 
std::move(shape), strides_in = std::move(strides_in), strides_out = std::move(strides_out), axes = axes_, inverse = inverse_, in_ptr, out_ptr, scale]() { pocketfft::c2r( shape, strides_in, strides_out, axes, !inverse, in_ptr, out_ptr, scale); }); } else { throw std::runtime_error( "[FFT] Received unexpected input and output type combination."); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/gemm.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include "mlx/array.h" namespace mlx::core { template void matmul( const T* a, const T* b, T* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, size_t ldc, float alpha, float beta, size_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides); } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/gemms/bnns.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include #include "mlx/array.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/gemm.h" #include "mlx/dtype.h" namespace mlx::core { template constexpr BNNSDataType to_bnns_dtype(); template <> constexpr BNNSDataType to_bnns_dtype() { return BNNSDataType(BNNSDataTypeFloatBit | 32); } template <> constexpr BNNSDataType to_bnns_dtype() { return BNNSDataType(BNNSDataTypeFloatBit | 16); } template <> constexpr BNNSDataType to_bnns_dtype() { return BNNSDataTypeBFloat16; } template void matmul_bnns( const T* a, const T* b, T* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, size_t ldc, float alpha, float beta, size_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides) { auto ndim = a_shape.size(); size_t M = a_shape[ndim - 2]; size_t N = b_shape[ndim - 1]; size_t K = a_shape[ndim - 1]; BNNSDataType bnns_dtype = to_bnns_dtype(); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wdeprecated-declarations" if (beta != 1.0 && beta != 0.0) { // scale the output for (auto i = 0; i < batch_size * M * N; ++i) { out[i] *= beta; } beta = 1.0; } const BNNSLayerParametersBroadcastMatMul gemm_params{ /* float alpha = */ alpha, /* float beta = */ beta, /* bool transA = */ a_transposed, /* bool transB = */ b_transposed, /* bool quadratic = */ false, /* bool a_is_weights = */ false, /* bool b_is_weights = */ false, /* BNNSNDArrayDescriptor iA_desc = */ BNNSNDArrayDescriptor{ /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet, /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix, /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */ {lda, (M * K) / lda, 0, 0, 0, 0, 0, 0}, /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */ {1, lda, 0, 0, 0, 0, 0, 0}, /* void * _Nullable data = */ nullptr, /* BNNSDataType data_type = */ bnns_dtype, /* void * _Nullable table_data = */ nullptr, /* BNNSDataType table_data_type = */ bnns_dtype, /* float data_scale = */ 1.0, /* float data_bias = */ 0.0, }, /* 
BNNSNDArrayDescriptor iB_desc = */ BNNSNDArrayDescriptor{ /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet, /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix, /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */ {ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0}, /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */ {1, ldb, 0, 0, 0, 0, 0, 0}, /* void * _Nullable data = */ nullptr, /* BNNSDataType data_type = */ bnns_dtype, /* void * _Nullable table_data = */ nullptr, /* BNNSDataType table_data_type = */ bnns_dtype, /* float data_scale = */ 1.0, /* float data_bias = */ 0.0, }, /* BNNSNDArrayDescriptor o_desc = */ BNNSNDArrayDescriptor{ /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet, /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix, /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */ {N, M, 0, 0, 0, 0, 0, 0}, /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */ {1, N, 0, 0, 0, 0, 0, 0}, /* void * _Nullable data = */ nullptr, /* BNNSDataType data_type = */ bnns_dtype, /* void * _Nullable table_data = */ nullptr, /* BNNSDataType table_data_type = */ bnns_dtype, /* float data_scale = */ 1.0, /* float data_bias = */ 0.0, }, }; auto bnns_filter = BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr); for (int i = 0; i < batch_size; ++i) { BNNSFilterApplyTwoInput( bnns_filter, reinterpret_cast( a + elem_to_loc(M * K * i, a_shape, a_strides)), reinterpret_cast( b + elem_to_loc(K * N * i, b_shape, b_strides)), reinterpret_cast(out + M * N * i)); } BNNSFilterDestroy(bnns_filter); #pragma GCC diagnostic pop } template <> void matmul( const float16_t* a, const float16_t* b, float16_t* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, size_t ldc, float alpha, float beta, size_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides) { matmul_bnns( a, b, out, a_transposed, b_transposed, lda, ldb, ldc, alpha, beta, batch_size, a_shape, a_strides, b_shape, b_strides); } template <> void matmul( const 
bfloat16_t* a, const bfloat16_t* b, bfloat16_t* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, size_t ldc, float alpha, float beta, size_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides) { matmul_bnns( a, b, out, a_transposed, b_transposed, lda, ldb, ldc, alpha, beta, batch_size, a_shape, a_strides, b_shape, b_strides); } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/gemms/cblas.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/gemm.h" #include "mlx/backend/cpu/lapack.h" namespace mlx::core { template <> void matmul( const float* a, const float* b, float* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, size_t ldc, float alpha, float beta, size_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides) { auto ndim = a_shape.size(); size_t M = a_shape[ndim - 2]; size_t N = b_shape[ndim - 1]; size_t K = a_shape[ndim - 1]; for (int i = 0; i < batch_size; ++i) { cblas_sgemm( CblasRowMajor, a_transposed ? CblasTrans : CblasNoTrans, // transA b_transposed ? CblasTrans : CblasNoTrans, // transB M, N, K, alpha, a + elem_to_loc(M * K * i, a_shape, a_strides), lda, b + elem_to_loc(K * N * i, b_shape, b_strides), ldb, beta, out + M * N * i, ldc); } } template <> void matmul( const double* a, const double* b, double* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, size_t ldc, float alpha, float beta, size_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides) { auto ndim = a_shape.size(); size_t M = a_shape[ndim - 2]; size_t N = b_shape[ndim - 1]; size_t K = a_shape[ndim - 1]; for (int i = 0; i < batch_size; ++i) { cblas_dgemm( CblasRowMajor, a_transposed ? CblasTrans : CblasNoTrans, // transA b_transposed ? 
CblasTrans : CblasNoTrans, // transB M, N, K, alpha, a + elem_to_loc(M * K * i, a_shape, a_strides), lda, b + elem_to_loc(K * N * i, b_shape, b_strides), ldb, beta, out + M * N * i, ldc); } } template <> void matmul( const complex64_t* a, const complex64_t* b, complex64_t* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, size_t ldc, float alpha, float beta, size_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides) { auto ndim = a_shape.size(); size_t M = a_shape[ndim - 2]; size_t N = b_shape[ndim - 1]; size_t K = a_shape[ndim - 1]; auto calpha = static_cast(alpha); auto cbeta = static_cast(beta); for (int i = 0; i < batch_size; ++i) { cblas_cgemm( CblasRowMajor, a_transposed ? CblasTrans : CblasNoTrans, // transA b_transposed ? CblasTrans : CblasNoTrans, // transB M, N, K, &calpha, a + elem_to_loc(M * K * i, a_shape, a_strides), lda, b + elem_to_loc(K * N * i, b_shape, b_strides), ldb, &cbeta, out + M * N * i, ldc); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/gemms/simd_bf16.cpp ================================================ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"

namespace mlx::core {

// bfloat16 GEMM using the portable SIMD kernel in simd_gemm.h, called once
// per batch element. simd_gemm addresses a, b and out densely from M, N and
// K, so lda/ldb/ldc are not used on this path.
template <>
void matmul(
    const bfloat16_t* a,
    const bfloat16_t* b,
    bfloat16_t* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  auto ndim = a_shape.size();
  size_t M = a_shape[ndim - 2];
  size_t N = b_shape[ndim - 1];
  size_t K = a_shape[ndim - 1];

  // size_t index to match batch_size (no signed/unsigned comparison).
  for (size_t i = 0; i < batch_size; ++i) {
    simd_gemm(
        a + elem_to_loc(M * K * i, a_shape, a_strides),
        b + elem_to_loc(K * N * i, b_shape, b_strides),
        out + M * N * i,
        a_transposed,
        b_transposed,
        M,
        N,
        K,
        alpha,
        beta);
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cpu/gemms/simd_fp16.cpp
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"

namespace mlx::core {

// float16 GEMM using the portable SIMD kernel in simd_gemm.h, called once
// per batch element. simd_gemm addresses a, b and out densely from M, N and
// K, so lda/ldb/ldc are not used on this path.
template <>
void matmul(
    const float16_t* a,
    const float16_t* b,
    float16_t* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  auto ndim = a_shape.size();
  size_t M = a_shape[ndim - 2];
  size_t N = b_shape[ndim - 1];
  size_t K = a_shape[ndim - 1];

  // size_t index to match batch_size (no signed/unsigned comparison).
  for (size_t i = 0; i < batch_size; ++i) {
    simd_gemm(
        a + elem_to_loc(M * K * i, a_shape, a_strides),
        b + elem_to_loc(K * N * i, b_shape, b_strides),
        out + M * N * i,
        a_transposed,
        b_transposed,
        M,
        N,
        K,
        alpha,
        beta);
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cpu/gemms/simd_gemm.h
================================================
// Copyright © 2025 Apple Inc.
#pragma once

#include "mlx/backend/cpu/simd/simd.h"

namespace mlx::core {

// Integer ceiling of a / b (for non-negative a and positive b).
inline int ceildiv(int a, int b) {
  return (a + b - 1) / b;
}

// Copies tile (i, j) of the dense row-major M x N matrix `in` into the
// block_size x block_size scratch tile `out`, converting T to AccT on
// assignment.
// - Without `transpose`, tile element (ii, jj) is stored at
//   out[ii * block_size + jj].
// - With `transpose`, it is stored at out[jj * block_size + ii].
// Tile positions that fall outside the matrix (partial edge tiles) are not
// written.
template void load_block(
    const T* in,
    AccT* out,
    int M,
    int N,
    int i,
    int j,
    bool transpose) {
  if (transpose) {
    for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
      for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
        out[jj * block_size + ii] =
            in[(i * block_size + ii) * N + j * block_size + jj];
      }
    }
  } else {
    for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
      for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
        out[ii * block_size + jj] =
            in[(i * block_size + ii) * N + j * block_size + jj];
      }
    }
  }
}

// Single (non-batched) GEMM on dense row-major operands:
//   c = alpha * op(a) * op(b) + beta * c
// c is M x N, op(a) is M x K and op(b) is K x N. With a_trans, `a` is stored
// as K x M; with b_trans, `b` is stored as N x K. Products are accumulated
// in AccT one 16 x 16 output tile at a time and converted to T on store.
// When beta == 0, c is only written, never read.
template void simd_gemm(
    const T* a,
    const T* b,
    T* c,
    bool a_trans,
    bool b_trans,
    int M,
    int N,
    int K,
    float alpha,
    float beta) {
  constexpr int block_size = 16;
  constexpr int simd_size = simd::max_size;
  static_assert(
      (block_size % simd_size) == 0,
      "Block size must be divisible by SIMD size");

  // K splits into K / block_size full tiles plus a remainder of
  // last_k_block_size (== K % block_size) columns. The SIMD-vectorizable part
  // of that remainder is its first last_k_simd_block columns.
  int last_k_block_size = K - block_size * (K / block_size);
  int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
  for (int i = 0; i < ceildiv(M, block_size); i++) {
    for (int j = 0; j < ceildiv(N, block_size); j++) {
      // Scratch tiles: a_block is laid out [m][k] and b_block [n][k] (see the
      // load_block calls), so both operands of every dot product are
      // contiguous along k.
      AccT c_block[block_size * block_size] = {0.0};
      AccT a_block[block_size * block_size];
      AccT b_block[block_size * block_size];

      int k = 0;
      for (; k < K / block_size; k++) {
        // Load a and b blocks
        if (a_trans) {
          load_block(a, a_block, K, M, k, i, true);
        } else {
          load_block(a, a_block, M, K, i, k, false);
        }
        if (b_trans) {
          load_block(b, b_block, N, K, j, k, false);
        } else {
          load_block(b, b_block, K, N, k, j, true);
        }

        // Multiply and accumulate
        for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
          for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
            for (int kk = 0; kk < block_size; kk += simd_size) {
              auto av = simd::load(a_block + ii * block_size + kk);
              auto bv = simd::load(b_block + jj * block_size + kk);
              c_block[ii * block_size + jj] += simd::sum(av * bv);
            }
          }
        }
      }
      // Remainder tile along K (k == K / block_size at this point). Only its
      // first last_k_block_size columns were loaded, so the dot products
      // stop there: SIMD chunks first, then a scalar tail.
      if (last_k_block_size) {
        // Load a and b blocks
        if (a_trans) {
          load_block(a, a_block, K, M, k, i, true);
        } else {
          load_block(a, a_block, M, K, i, k, false);
        }
        if (b_trans) {
          load_block(b, b_block, N, K, j, k, false);
        } else {
          load_block(b, b_block, K, N, k, j, true);
        }

        // Multiply and accumulate
        for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
          for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
            int kk = 0;
            for (; kk < last_k_simd_block; kk += simd_size) {
              auto av = simd::load(a_block + ii * block_size + kk);
              auto bv = simd::load(b_block + jj * block_size + kk);
              c_block[ii * block_size + jj] += simd::sum(av * bv);
            }
            for (; kk < last_k_block_size; ++kk) {
              c_block[ii * block_size + jj] +=
                  a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
            }
          }
        }
      }

      // Store: apply alpha/beta and write back with row stride N.
      for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
        for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
          auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
          if (beta != 0) {
            c[c_idx] = static_cast(
                alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
          } else {
            c[c_idx] = static_cast(alpha * c_block[ii * block_size + jj]);
          }
        }
      }
    }
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cpu/hadamard.cpp
================================================
// Copyright © 2024 Apple Inc.
#include #include "mlx/backend/common/hadamard.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" namespace mlx::core { // n = 2^k component template void hadamard_n(T* out, int n, int m, float scale, size_t size) { for (int b = 0; b < size / n; b++) { size_t loc = b * n; T* data_ptr = out + loc; int h = 1; int n_over_2 = n / 2; while (h < n) { for (int i = 0; i < n / 2; i++) { int k = i & (h - 1); int j = ((i - k) << 1) + k; float x = *(data_ptr + j); float y = *(data_ptr + j + h); *(data_ptr + j) = x + y; *(data_ptr + j + h) = x - y; if (h == n_over_2) { *(data_ptr + j) *= scale; *(data_ptr + j + h) *= scale; } } h <<= 1; } } } // m component template void hadamard_m(T* out, int n, int m, float scale, size_t size) { auto h_matrices = hadamard_matrices(); auto& matrix = h_matrices[m]; auto start = 1; auto end = matrix.find('\n', start); std::vector hmat_vec; while (end != std::string_view::npos) { auto row = matrix.substr(start, end - start); for (int i = 0; i < row.length(); i++) { hmat_vec.push_back(row[i] == '+'); } start = end + 1; end = matrix.find('\n', start); } for (int b = 0; b < size / m / n; b++) { size_t loc = b * n * m; T* data_ptr = out + loc; for (int i = 0; i < n; i++) { std::vector out(m); for (int j = 0; j < m; j++) { for (int k = 0; k < m; k++) { float x = *(data_ptr + i + k * n); if (hmat_vec[k + j * m]) { out[j] += x; } else { out[j] -= x; } } } for (int j = 0; j < m; j++) { *(data_ptr + i + j * n) = out[j] * scale; } } } } template void hadamard(array& out, int n, int m, float scale, Stream stream) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(out); auto out_ptr = out.data(); encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() { float n_scale = m > 1 ? 
1.0 : scale; hadamard_n(out_ptr, n, m, n_scale, size); if (m > 1) { hadamard_m(out_ptr, n, m, scale, size); } }); } void Hadamard::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; // Copy input to output if (in.flags().row_contiguous && in.is_donatable()) { out.copy_shared_buffer(in); } else { copy_cpu( in, out, in.flags().row_contiguous ? CopyType::Vector : CopyType::General, stream()); } int axis = out.ndim() - 1; auto [n, m] = decompose_hadamard(out.shape(axis)); switch (in.dtype()) { case float32: return hadamard(out, n, m, scale_, stream()); case float16: return hadamard(out, n, m, scale_, stream()); case bfloat16: return hadamard(out, n, m, scale_, stream()); default: throw std::invalid_argument("[hadamard] Unsupported type."); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/indexing.cpp ================================================ // Copyright © 2023 Apple Inc. #include #include #include #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/binary.h" #include "mlx/backend/cpu/binary_ops.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/slicing.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" namespace mlx::core { template inline size_t offset_neg_idx(IdxT idx, size_t size) { return (idx < 0) ? idx + size : idx; } template <> inline size_t offset_neg_idx(uint32_t idx, size_t) { return idx; } struct None { template void operator()(T x, T* y) { (*y) = x; } }; struct Sum { template void operator()(T x, T* y) { (*y) += x; } }; struct Prod { template void operator()(T x, T* y) { (*y) *= x; } }; struct Max { template void operator()(T x, T* y) { (*y) = (*y > x) ? *y : x; } }; struct Min { template void operator()(T x, T* y) { (*y) = (*y < x) ? 
*y : x; } }; template void gather( const array& src, const std::vector& inds, array& out, const std::vector& axes, const Shape& slice_sizes) { // If the array is row contiguous then we can do a contiguous copy given // two conditions on the slice size: // - Any number of leading ones in the slice sizes are allowed // - All other slice sizes match the corresponding dimension except the // first non-singleton slice size // If the array is col contiguous then the reverse is the case: // - Any number of trailing ones in the slice sizes are allowed // - All other slice sizes match the corresponding dimension except the // first non-singleton slice size from the end bool can_copy = false; if (src.flags().row_contiguous) { can_copy = true; // Ignore leading 1s int i = 0; for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i) ; // Check the remaining i++; for (; i < src.ndim() && can_copy; ++i) { can_copy = (src.shape(i) == slice_sizes[i]); } } else if (src.flags().col_contiguous) { can_copy = true; // Ignore trailing 1s int i = slice_sizes.size() - 1; for (; i >= 0 && slice_sizes[i] == 1; --i) ; // Skip the next slice size and check the remaining i--; for (; i >= 0 && can_copy; --i) { can_copy = (src.shape(i) == slice_sizes[i]); } } size_t slice_size = 1; for (auto s : slice_sizes) { slice_size *= s; } size_t ind_size = slice_size == 0 ? 
0 : out.size() / slice_size; const T* src_ptr = src.data(); T* dst_ptr = out.data(); std::vector its(inds.begin(), inds.end()); ContiguousIterator src_it; if (!can_copy && src.ndim() > 0) { src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim()); } size_t out_idx = 0; for (int idx = 0; idx < ind_size; idx++) { size_t src_idx = 0; for (int ii = 0; ii < inds.size(); ++ii) { auto ax = axes[ii]; auto idx_loc = its[ii].loc; its[ii].step(); auto idx_val = offset_neg_idx(inds[ii].data()[idx_loc], src.shape(ax)); src_idx += (idx_val * src.strides()[ax]); } if (slice_size == 1) { dst_ptr[out_idx++] = src_ptr[src_idx]; } else if (can_copy) { std::copy( src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx); out_idx += slice_size; } else { for (int jj = 0; jj < slice_size; jj++) { dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc]; src_it.step(); } src_it.reset(); } } } template void dispatch_gather( const array& src, const std::vector& inds, array& out, const std::vector& axes, const Shape& size) { switch (out.dtype()) { case bool_: gather(src, inds, out, axes, size); break; case uint8: gather(src, inds, out, axes, size); break; case uint16: gather(src, inds, out, axes, size); break; case uint32: gather(src, inds, out, axes, size); break; case uint64: gather(src, inds, out, axes, size); break; case int8: gather(src, inds, out, axes, size); break; case int16: gather(src, inds, out, axes, size); break; case int32: gather(src, inds, out, axes, size); break; case int64: gather(src, inds, out, axes, size); break; case float16: gather(src, inds, out, axes, size); break; case float32: gather(src, inds, out, axes, size); break; case float64: gather(src, inds, out, axes, size); break; case bfloat16: gather(src, inds, out, axes, size); break; case complex64: gather(src, inds, out, axes, size); break; } } void Gather::eval_cpu(const std::vector& inputs, array& out) { out.set_data(allocator::malloc(out.nbytes())); auto& src = inputs[0]; std::vector inds; for 
(auto it = inputs.begin() + 1; it < inputs.end(); ++it) { inds.push_back(array::unsafe_weak_copy(*it)); } auto& encoder = cpu::get_command_encoder(stream()); for (auto& in : inputs) { encoder.set_input_array(in); } encoder.set_output_array(out); encoder.dispatch([axes_ = axes_, slice_sizes_ = slice_sizes_, src = array::unsafe_weak_copy(src), inds = std::move(inds), out = array::unsafe_weak_copy(out)]() mutable { if (inds.empty()) { dispatch_gather(src, inds, out, axes_, slice_sizes_); return; } switch (inds[0].dtype()) { case uint8: dispatch_gather(src, inds, out, axes_, slice_sizes_); break; case uint16: dispatch_gather(src, inds, out, axes_, slice_sizes_); break; case uint32: dispatch_gather(src, inds, out, axes_, slice_sizes_); break; case uint64: dispatch_gather(src, inds, out, axes_, slice_sizes_); break; case int8: dispatch_gather(src, inds, out, axes_, slice_sizes_); break; case int16: dispatch_gather(src, inds, out, axes_, slice_sizes_); break; case int32: dispatch_gather(src, inds, out, axes_, slice_sizes_); break; case int64: dispatch_gather(src, inds, out, axes_, slice_sizes_); break; default: throw std::runtime_error( "[Gather::eval_cpu] Cannot gather with indices type."); break; } }); } template void gather_axis( const array& src, const array& ind, array& out, const int axis) { auto shape = remove_index(ind.shape(), axis); ContiguousIterator ind_it( shape, remove_index(ind.strides(), axis), src.ndim() - 1); ContiguousIterator src_it( shape, remove_index(src.strides(), axis), src.ndim() - 1); auto ind_ptr = ind.data(); auto src_ptr = src.data(); auto dst_ptr = out.data(); auto ind_ax_stride = ind.strides(axis); auto src_ax_stride = src.strides(axis); auto dst_ax_stride = out.strides(axis); auto ind_ax_size = ind.shape(axis); auto src_ax_size = src.shape(axis); size_t size_pre = 1; size_t size_post = 1; for (int i = 0; i < axis; ++i) { size_pre *= ind.shape(i); } for (int i = axis + 1; i < ind.ndim(); ++i) { size_post *= ind.shape(i); } size_t stride_pre 
= size_post * ind_ax_size; for (size_t i = 0; i < size_pre; i++) { for (size_t k = 0; k < size_post; k++) { for (int j = 0; j < ind_ax_size; ++j) { auto ind_val = offset_neg_idx( ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size); dst_ptr[k + j * dst_ax_stride] = src_ptr[src_it.loc + ind_val * src_ax_stride]; } ind_it.step(); src_it.step(); } dst_ptr += stride_pre; } } template void dispatch_gather_axis( const array& src, const array& inds, array& out, const int axis) { switch (out.dtype()) { case bool_: gather_axis(src, inds, out, axis); break; case uint8: gather_axis(src, inds, out, axis); break; case uint16: gather_axis(src, inds, out, axis); break; case uint32: gather_axis(src, inds, out, axis); break; case uint64: gather_axis(src, inds, out, axis); break; case int8: gather_axis(src, inds, out, axis); break; case int16: gather_axis(src, inds, out, axis); break; case int32: gather_axis(src, inds, out, axis); break; case int64: gather_axis(src, inds, out, axis); break; case float16: gather_axis(src, inds, out, axis); break; case float32: gather_axis(src, inds, out, axis); break; case float64: gather_axis(src, inds, out, axis); break; case bfloat16: gather_axis(src, inds, out, axis); break; case complex64: gather_axis(src, inds, out, axis); break; } } void GatherAxis::eval_cpu(const std::vector& inputs, array& out) { out.set_data(allocator::malloc(out.nbytes())); auto& src = inputs[0]; auto& inds = inputs[1]; auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(src); encoder.set_input_array(inds); encoder.set_output_array(out); encoder.dispatch([axis_ = axis_, src = array::unsafe_weak_copy(src), inds = array::unsafe_weak_copy(inds), out = array::unsafe_weak_copy(out)]() mutable { switch (inds.dtype()) { case uint8: dispatch_gather_axis(src, inds, out, axis_); break; case uint16: dispatch_gather_axis(src, inds, out, axis_); break; case uint32: dispatch_gather_axis(src, inds, out, axis_); break; case uint64: dispatch_gather_axis(src, inds, 
out, axis_); break; case int8: dispatch_gather_axis(src, inds, out, axis_); break; case int16: dispatch_gather_axis(src, inds, out, axis_); break; case int32: dispatch_gather_axis(src, inds, out, axis_); break; case int64: dispatch_gather_axis(src, inds, out, axis_); break; default: throw std::runtime_error( "[GatherAxis::eval_cpu] Cannot gather with indices type."); break; } }); } template void scatter( const array& updates, array& out, const std::vector& inds, const std::vector& axes) { int nind = inds.size(); auto inds_ndim = updates.ndim() - out.ndim(); size_t n_updates = nind ? inds[0].size() : 1; Shape update_shape( updates.shape().begin() + inds_ndim, updates.shape().end()); size_t update_size = 1; for (auto us : update_shape) { update_size *= us; } std::vector its(inds.begin(), inds.end()); ContiguousIterator update_it(updates); ContiguousIterator out_it(update_shape, out.strides(), out.ndim()); auto out_ptr = out.data(); auto upd_ptr = updates.data(); for (int i = 0; i < n_updates; ++i) { size_t out_offset = 0; for (int j = 0; j < inds.size(); ++j) { auto ax = axes[j]; auto idx_loc = its[j].loc; its[j].step(); auto idx_val = offset_neg_idx(inds[j].data()[idx_loc], out.shape(ax)); out_offset += (idx_val * out.strides()[ax]); } update_it.seek(i * update_size); for (int j = 0; j < update_size; ++j) { OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc); update_it.step(); out_it.step(); } out_it.reset(); update_it.reset(); } } template void dispatch_scatter_inds( array& out, const std::vector& indices, const array& updates, const std::vector& axes, Scatter::ReduceType rtype) { switch (rtype) { case Scatter::None: scatter(updates, out, indices, axes); break; case Scatter::Sum: scatter(updates, out, indices, axes); break; case Scatter::Prod: scatter(updates, out, indices, axes); break; case Scatter::Max: scatter(updates, out, indices, axes); break; case Scatter::Min: scatter(updates, out, indices, axes); break; } } template void dispatch_scatter( 
array& out, const std::vector& inds, const array& updates, const std::vector& axes, Scatter::ReduceType rtype) { if (inds.empty()) { dispatch_scatter_inds(out, inds, updates, axes, rtype); return; } switch (inds[0].dtype()) { case uint8: dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case uint16: dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case uint32: dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case uint64: dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case int8: dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case int16: dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case int32: dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case int64: dispatch_scatter_inds(out, inds, updates, axes, rtype); break; default: throw std::runtime_error( "[Scatter::eval_cpu] Cannot scatter with indices type."); } } void Scatter::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() >= 2); auto& src = inputs[0]; auto& updates = inputs.back(); // Copy src into out (copy allocates memory for out) auto ctype = src.flags().row_contiguous ? 
CopyType::Vector : CopyType::General; copy_cpu(src, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream()); std::vector inds; for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) { encoder.set_input_array(*it); inds.push_back(array::unsafe_weak_copy(*it)); } encoder.set_input_array(updates); encoder.set_output_array(out); encoder.dispatch([axes_ = axes_, reduce_type_ = reduce_type_, updates = array::unsafe_weak_copy(updates), inds = std::move(inds), out = array::unsafe_weak_copy(out)]() mutable { switch (out.dtype()) { case bool_: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case uint8: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case uint16: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case uint32: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case uint64: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case int8: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case int16: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case int32: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case int64: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case float16: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case float32: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case float64: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case bfloat16: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; case complex64: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; } }); } template void scatter_axis(array& out, const array idx, const array& upd, int axis) { auto shape = remove_index(idx.shape(), axis); ContiguousIterator idx_it( shape, remove_index(idx.strides(), axis), upd.ndim() - 1); ContiguousIterator upd_it( shape, remove_index(upd.strides(), axis), upd.ndim() - 1); auto idx_ptr = 
idx.data(); auto upd_ptr = upd.data(); auto dst_ptr = out.data(); auto idx_ax_stride = idx.strides(axis); auto upd_ax_stride = upd.strides(axis); auto dst_ax_stride = out.strides(axis); auto idx_ax_size = idx.shape(axis); auto dst_ax_size = out.shape(axis); size_t size_pre = 1; size_t size_post = 1; for (int i = 0; i < axis; ++i) { size_pre *= idx.shape(i); } for (int i = axis + 1; i < idx.ndim(); ++i) { size_post *= idx.shape(i); } size_t stride_pre = size_post * dst_ax_size; for (size_t i = 0; i < size_pre; i++) { for (size_t k = 0; k < size_post; k++) { for (int j = 0; j < idx_ax_size; ++j) { auto ind_val = offset_neg_idx( idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size); OpT{}( upd_ptr[upd_it.loc + j * upd_ax_stride], dst_ptr + k + ind_val * dst_ax_stride); } idx_it.step(); upd_it.step(); } dst_ptr += stride_pre; } } template void dispatch_scatter_axis_op( array& out, const array& idx, const array& updates, int axis, ScatterAxis::ReduceType rtype) { switch (rtype) { case ScatterAxis::None: scatter_axis(out, idx, updates, axis); break; case ScatterAxis::Sum: scatter_axis(out, idx, updates, axis); break; } } template void dispatch_scatter_axis( array& out, const array& idx, const array& updates, int axis, ScatterAxis::ReduceType rtype) { switch (idx.dtype()) { case uint8: dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case uint16: dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case uint32: dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case uint64: dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case int8: dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case int16: dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case int32: dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case int64: dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; default: throw std::runtime_error( "[ScatterAxis::eval_cpu] Cannot scatter with indices 
type."); } } void ScatterAxis::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() >= 2); auto& src = inputs[0]; auto& idx = inputs[1]; auto& updates = inputs[2]; // Copy src into out (copy allocates memory for out) auto ctype = src.flags().row_contiguous ? CopyType::Vector : CopyType::General; copy_cpu(src, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(idx); encoder.set_input_array(updates); encoder.set_output_array(out); encoder.dispatch([axis_ = axis_, reduce_type_ = reduce_type_, idx = array::unsafe_weak_copy(idx), updates = array::unsafe_weak_copy(updates), out = array::unsafe_weak_copy(out)]() mutable { switch (out.dtype()) { case bool_: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case uint8: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case uint16: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case uint32: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case uint64: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case int8: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case int16: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case int32: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case int64: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case float16: dispatch_scatter_axis( out, idx, updates, axis_, reduce_type_); break; case float32: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case float64: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; case bfloat16: dispatch_scatter_axis( out, idx, updates, axis_, reduce_type_); break; case complex64: dispatch_scatter_axis( out, idx, updates, axis_, reduce_type_); break; } }); } template void masked_scatter_impl(const array& mask, const array& src, array& out) { ContiguousIterator 
mask_it(mask); ContiguousIterator src_it(src); ContiguousIterator out_it(out); const bool* mask_ptr = mask.data(); const T* src_ptr = src.data(); T* dst_ptr = out.data(); const size_t batch_count = mask.shape(0); const size_t mask_batch_size = mask.size() / batch_count; const size_t src_batch_size = src.size() / batch_count; for (size_t b = 0; b < batch_count; ++b) { size_t src_consumed = 0; src_it.seek(b * src_batch_size); for (size_t i = 0; i < mask_batch_size; ++i) { if (mask_ptr[mask_it.loc]) { if (src_consumed >= src_batch_size) { throw std::runtime_error( "[MaskedScatter::eval_cpu] Source does not have enough elements for mask."); } dst_ptr[out_it.loc] = src_ptr[src_it.loc]; src_it.step(); ++src_consumed; } mask_it.step(); out_it.step(); } } } void MaskedScatter::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); auto& dst = inputs[0]; auto& mask = inputs[1]; auto& src = inputs[2]; // Copy dst into out (copy allocates memory for out) auto ctype = dst.flags().row_contiguous ? 
CopyType::Vector : CopyType::General; copy_cpu(dst, out, ctype, stream()); if (mask.size() == 0) { return; } auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(mask); encoder.set_input_array(src); encoder.set_output_array(out); encoder.dispatch([mask = array::unsafe_weak_copy(mask), src = array::unsafe_weak_copy(src), out = array::unsafe_weak_copy(out)]() mutable { switch (out.dtype()) { case bool_: masked_scatter_impl(mask, src, out); break; case uint8: masked_scatter_impl(mask, src, out); break; case uint16: masked_scatter_impl(mask, src, out); break; case uint32: masked_scatter_impl(mask, src, out); break; case uint64: masked_scatter_impl(mask, src, out); break; case int8: masked_scatter_impl(mask, src, out); break; case int16: masked_scatter_impl(mask, src, out); break; case int32: masked_scatter_impl(mask, src, out); break; case int64: masked_scatter_impl(mask, src, out); break; case float16: masked_scatter_impl(mask, src, out); break; case float32: masked_scatter_impl(mask, src, out); break; case float64: masked_scatter_impl(mask, src, out); break; case bfloat16: masked_scatter_impl(mask, src, out); break; case complex64: masked_scatter_impl(mask, src, out); break; } }); } template void slice_update_impl( array& out, const array& upd, int64_t data_offset, const Strides& out_strides) { ContiguousIterator out_it(upd.shape(), out_strides, upd.ndim()); ContiguousIterator upd_it(upd); Op op; constexpr int SIMD_START = 32; T* out_ptr = out.data() + data_offset; const T* upd_ptr = upd.data(); int64_t size = upd.size(); int64_t suffix = out_it.contiguous_suffix(); if (upd.data_size() == 1) { if (suffix >= SIMD_START) { for (int64_t i = 0; i < size; i += suffix) { VectorScalar{}( out_ptr + out_it.loc, upd_ptr, out_ptr + out_it.loc, suffix); out_it.step(suffix); } } else { T update = upd_ptr[0]; for (int64_t i = 0; i < size; i++) { out_ptr[out_it.loc] = op(out_ptr[out_it.loc], update); out_it.step(); } } } else if (suffix == 
upd_it.contiguous_suffix() && suffix >= SIMD_START) { for (int64_t i = 0; i < size; i += suffix) { VectorVector{}( out_ptr + out_it.loc, upd_ptr + upd_it.loc, out_ptr + out_it.loc, suffix); out_it.step(suffix); upd_it.step(suffix); } } else { for (int64_t i = 0; i < size; i++) { out_ptr[out_it.loc] = op(out_ptr[out_it.loc], upd_ptr[upd_it.loc]); out_it.step(); upd_it.step(); } } } void SliceUpdate::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (out.size() == 0) { out.set_data(allocator::malloc(0)); return; } auto& in = inputs[0]; auto& upd = inputs[1]; if (upd.size() == 0) { out.copy_shared_buffer(in); return; } // Check if materialization is needed auto ctype = in.flags().contiguous && in.size() == in.data_size() ? CopyType::Vector : CopyType::General; copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); // Calculate out strides, initial offset and if copy needs to be made auto [data_offset, out_strides] = prepare_slice(out, start_indices_, strides_); // Do copy if (reduce_type_ == SliceUpdate::None) { copy_cpu_inplace( /* const array& src = */ upd, /* array& dst = */ out, /* const std::vector& data_shape = */ upd.shape(), /* const std::vector& i_strides = */ upd.strides(), /* const std::vector& o_strides = */ out_strides, /* int64_t i_offset = */ 0, /* int64_t o_offset = */ data_offset, /* CopyType ctype = */ CopyType::GeneralGeneral, stream()); return; } auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(upd); encoder.set_output_array(out); encoder.dispatch([upd = array::unsafe_weak_copy(upd), out = array::unsafe_weak_copy(out), data_offset = data_offset, out_strides = std::move(out_strides), reduce_type = reduce_type_]() mutable { dispatch_all_types(out.dtype(), [&](auto type_tag) { using T = MLX_GET_TYPE(type_tag); switch (reduce_type) { case SliceUpdate::Sum: slice_update_impl(out, upd, data_offset, out_strides); break; case SliceUpdate::Prod: slice_update_impl( out, upd, 
data_offset, out_strides); break; case SliceUpdate::Max: slice_update_impl( out, upd, data_offset, out_strides); break; case SliceUpdate::Min: slice_update_impl( out, upd, data_offset, out_strides); break; case SliceUpdate::None: // Should never be here break; } }); }); } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/inverse.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include "mlx/allocator.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" namespace mlx::core { template void general_inv(T* inv, int N) { int info; auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)}; // Compute LU factorization. getrf( /* m = */ &N, /* n = */ &N, /* a = */ inv, /* lda = */ &N, /* ipiv = */ static_cast(ipiv.buffer.raw_ptr()), /* info = */ &info); if (info != 0) { std::stringstream ss; ss << "[Inverse::eval_cpu] LU factorization failed with error code " << info; throw std::runtime_error(ss.str()); } static const int lwork_query = -1; T workspace_size = 0; // Compute workspace size. getri( /* m = */ &N, /* a = */ nullptr, /* lda = */ &N, /* ipiv = */ nullptr, /* work = */ &workspace_size, /* lwork = */ &lwork_query, /* info = */ &info); if (info != 0) { std::stringstream ss; ss << "[Inverse::eval_cpu] LU workspace calculation failed with error code " << info; throw std::runtime_error(ss.str()); } const int lwork = workspace_size; auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)}; // Compute inverse. 
getri( /* m = */ &N, /* a = */ inv, /* lda = */ &N, /* ipiv = */ static_cast(ipiv.buffer.raw_ptr()), /* work = */ static_cast(scratch.buffer.raw_ptr()), /* lwork = */ &lwork, /* info = */ &info); if (info != 0) { std::stringstream ss; ss << "[Inverse::eval_cpu] inversion failed with error code " << info; throw std::runtime_error(ss.str()); } } template void tri_inv(T* inv, int N, bool upper) { const char uplo = upper ? 'L' : 'U'; const char diag = 'N'; int info; trtri( /* uplo = */ &uplo, /* diag = */ &diag, /* N = */ &N, /* a = */ inv, /* lda = */ &N, /* info = */ &info); // zero out the other triangle if (upper) { for (int i = 0; i < N; i++) { std::fill(inv, inv + i, 0.0f); inv += N; } } else { for (int i = 0; i < N; i++) { std::fill(inv + i + 1, inv + N, 0.0f); inv += N; } } if (info != 0) { std::stringstream ss; ss << "[Inverse::eval_cpu] triangular inversion failed with error code " << info; throw std::runtime_error(ss.str()); } } template void inverse_impl( const array& a, array& inv, bool tri, bool upper, Stream stream) { // Lapack uses the column-major convention. We take advantage of the following // identity to avoid transposing (see // https://math.stackexchange.com/a/340234): // (A⁻¹)ᵀ = (Aᵀ)⁻¹ // The inverse is computed in place, so just copy the input to the output. copy_cpu( a, inv, a.flags().row_contiguous ? 
CopyType::Vector : CopyType::General, stream); const int N = a.shape(-1); const size_t num_matrices = a.size() / (N * N); auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(inv); auto inv_ptr = inv.data(); if (tri) { encoder.dispatch([inv_ptr, N, num_matrices, upper]() { for (int i = 0; i < num_matrices; i++) { tri_inv(inv_ptr + N * N * i, N, upper); } }); } else { encoder.dispatch([inv_ptr, N, num_matrices]() { for (int i = 0; i < num_matrices; i++) { general_inv(inv_ptr + N * N * i, N); } }); } } void Inverse::eval_cpu(const std::vector& inputs, array& output) { switch (inputs[0].dtype()) { case float32: inverse_impl(inputs[0], output, tri_, upper_, stream()); break; case float64: inverse_impl(inputs[0], output, tri_, upper_, stream()); break; default: throw std::runtime_error( "[Inverse::eval_cpu] only supports float32 or float64."); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/jit_compiler.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/backend/cpu/jit_compiler.h" #include #include #include #include namespace mlx::core { #ifdef _MSC_VER namespace { // Split string into array. std::vector str_split(const std::string& str, char delimiter) { std::vector tokens; std::string token; std::istringstream tokenStream(str); while (std::getline(tokenStream, token, delimiter)) { tokens.push_back(token); } return tokens; } // Get path information about MSVC. struct VisualStudioInfo { VisualStudioInfo() { #ifdef _M_ARM64 arch = "arm64"; #else arch = "x64"; #endif // Get path of Visual Studio. // Use -latest to get only the most recent installation when multiple // versions are installed, avoiding path concatenation issues. 
std::string vs_path = JitCompiler::exec( fmt::format( "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" " -latest -property installationPath", std::getenv("ProgramFiles(x86)"))); if (vs_path.empty()) { throw std::runtime_error("Can not find Visual Studio."); } // Trim any trailing whitespace/newlines from the path vs_path.erase( std::find_if( vs_path.rbegin(), vs_path.rend(), [](unsigned char ch) { return !std::isspace(ch); }) .base(), vs_path.end()); // Read the envs from vcvarsall. std::string envs = JitCompiler::exec( fmt::format( "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set", vs_path, arch)); for (const std::string& line : str_split(envs, '\n')) { // Each line is in the format "ENV_NAME=values". auto pos = line.find_first_of('='); if (pos == std::string::npos || pos == 0 || pos == line.size() - 1) continue; std::string name = line.substr(0, pos); std::string value = line.substr(pos + 1); if (name == "LIB") { libpaths = str_split(value, ';'); } else if (name == "VCToolsInstallDir" || name == "VCTOOLSINSTALLDIR") { cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch); } } } std::string arch; std::string cl_exe; std::vector libpaths; }; const VisualStudioInfo& GetVisualStudioInfo() { static VisualStudioInfo info; return info; } } // namespace #endif // _MSC_VER std::string JitCompiler::build_command( const std::filesystem::path& dir, const std::string& source_file_name, const std::string& shared_lib_name) { #ifdef _MSC_VER const VisualStudioInfo& info = GetVisualStudioInfo(); std::string libpaths; for (const std::string& lib : info.libpaths) { libpaths += fmt::format(" /libpath:\"{0}\"", lib); } return fmt::format( "\"" "cd /D \"{0}\" && " "\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" " "/link /out:\"{3}\" {4} 2>&1" "\"", dir.string(), info.cl_exe, source_file_name, shared_lib_name, libpaths); #else return fmt::format( "g++ -std=c++17 -O3 -Wall -fPIC -shared \"{0}\" -o \"{1}\" 2>&1", (dir / 
source_file_name).string(), (dir / shared_lib_name).string()); #endif } std::string JitCompiler::exec(const std::string& cmd) { #ifdef _MSC_VER FILE* pipe = _popen(cmd.c_str(), "r"); #else FILE* pipe = popen(cmd.c_str(), "r"); #endif if (!pipe) { throw std::runtime_error("popen() failed."); } char buffer[128]; std::string ret; while (fgets(buffer, sizeof(buffer), pipe)) { ret += buffer; } // Trim trailing spaces. ret.erase( std::find_if( ret.rbegin(), ret.rend(), [](unsigned char ch) { return !std::isspace(ch); }) .base(), ret.end()); #ifdef _MSC_VER int status = _pclose(pipe); #else int status = pclose(pipe); #endif if (status == -1) { throw std::runtime_error("pclose() failed."); } #if defined(_WIN32) || defined(__FreeBSD__) int code = status; #else int code = WEXITSTATUS(status); #endif if (code != 0) { throw std::runtime_error( fmt::format( "Failed to execute command with return code {0}: \"{1}\", " "the output is: {2}", code, cmd, ret)); } return ret; } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/jit_compiler.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include namespace mlx::core { class JitCompiler { public: // Build a shell command that compiles a source code file to a shared library. static std::string build_command( const std::filesystem::path& dir, const std::string& source_file_name, const std::string& shared_lib_name); // Run a command and get its output. static std::string exec(const std::string& cmd); }; } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/lapack.h ================================================ // Copyright © 2023-2024 Apple Inc. 
#pragma once #include #define LAPACK_COMPLEX_CUSTOM #define lapack_complex_float std::complex #define lapack_complex_double std::complex #define lapack_complex_float_real(z) ((z).real()) #define lapack_complex_float_imag(z) ((z).imag()) #define lapack_complex_double_real(z) ((z).real()) #define lapack_complex_double_imag(z) ((z).imag()) #ifdef MLX_USE_ACCELERATE #include #else #include #include #endif #if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME) // This is to work around a change in the function signatures of lapack >= 3.9.1 // where functions taking char* also include a strlen argument, see a similar // change in OpenCV: // https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57 #define MLX_LAPACK_FUNC(f) LAPACK_##f #else #define MLX_LAPACK_FUNC(f) f##_ #endif #define INSTANTIATE_LAPACK_REAL(FUNC) \ template \ void FUNC(Args... args) { \ if constexpr (std::is_same_v) { \ MLX_LAPACK_FUNC(s##FUNC)(std::forward(args)...); \ } else if constexpr (std::is_same_v) { \ MLX_LAPACK_FUNC(d##FUNC)(std::forward(args)...); \ } \ } INSTANTIATE_LAPACK_REAL(geqrf) INSTANTIATE_LAPACK_REAL(orgqr) INSTANTIATE_LAPACK_REAL(syevd) INSTANTIATE_LAPACK_REAL(potrf) INSTANTIATE_LAPACK_REAL(getrf) INSTANTIATE_LAPACK_REAL(getri) INSTANTIATE_LAPACK_REAL(trtri) #define INSTANTIATE_LAPACK_COMPLEX(FUNC) \ template \ void FUNC(Args... args) { \ if constexpr (std::is_same_v>) { \ MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ } else if constexpr (std::is_same_v>) { \ MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ } \ } INSTANTIATE_LAPACK_COMPLEX(heevd) #define INSTANTIATE_LAPACK_ALL(FUNC) \ template \ void FUNC(Args... 
args) { \ if constexpr (std::is_same_v) { \ MLX_LAPACK_FUNC(s##FUNC)(std::forward(args)...); \ } else if constexpr (std::is_same_v) { \ MLX_LAPACK_FUNC(d##FUNC)(std::forward(args)...); \ } else if constexpr (std::is_same_v>) { \ MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ } else if constexpr (std::is_same_v>) { \ MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ } \ } INSTANTIATE_LAPACK_ALL(geev) INSTANTIATE_LAPACK_ALL(gesdd) ================================================ FILE: mlx/backend/cpu/logsumexp.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include #include #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" #include "mlx/primitives.h" #include "mlx/types/limits.h" namespace mlx::core { namespace { using namespace mlx::core::simd; template void logsumexp(const array& in, array& out, Stream stream) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(in); encoder.set_output_array(out); const T* in_ptr = in.data(); T* out_ptr = out.data(); int M = in.shape().back(); int L = in.data_size() / M; encoder.dispatch([in_ptr, out_ptr, M, L]() mutable { constexpr int N = std::min(max_size, max_size); const T* current_in_ptr; for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) { // Find the maximum current_in_ptr = in_ptr; Simd vmaximum(-numeric_limits::infinity()); size_t s = M; while (s >= N) { Simd vals = load(current_in_ptr); vmaximum = maximum(vals, vmaximum); current_in_ptr += N; s -= N; } AccT maximum = max(vmaximum); while (s-- > 0) { maximum = std::max(maximum, static_cast(*current_in_ptr)); current_in_ptr++; } // Compute the normalizer and the exponentials Simd vnormalizer(0.0); current_in_ptr = in_ptr; s = M; while (s >= N) { Simd vexp = load(current_in_ptr); vexp = exp(vexp - maximum); vnormalizer = vnormalizer + vexp; current_in_ptr += N; s -= N; } AccT normalizer = sum(vnormalizer); while (s-- > 0) { AccT _exp = 
std::exp(*current_in_ptr - maximum); normalizer += _exp; current_in_ptr++; } // Normalize *out_ptr = std::isinf(maximum) ? static_cast(maximum) : static_cast(std::log(normalizer) + maximum); } }); } } // namespace void LogSumExp::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); // Make sure that the last dimension is contiguous auto s = stream(); auto& encoder = cpu::get_command_encoder(s); auto ensure_contiguous = [&s, &encoder](const array& x) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { return x; } else { array x_copy = contiguous_copy_cpu(x, s); encoder.add_temporary(x_copy); return x_copy; } }; auto in = ensure_contiguous(inputs[0]); if (in.flags().row_contiguous) { out.set_data(allocator::malloc(out.nbytes())); } else { auto n = in.shape(-1); auto flags = in.flags(); auto strides = in.strides(); for (auto& s : strides) { s /= n; } bool col_contig = strides[0] == 1; for (int i = 1; col_contig && i < strides.size(); ++i) { col_contig &= (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]); } flags.col_contiguous = col_contig; out.set_data( allocator::malloc(in.nbytes() / n), in.data_size() / n, std::move(strides), flags); } switch (in.dtype()) { case float32: logsumexp(in, out, stream()); break; case float16: logsumexp(in, out, stream()); break; case bfloat16: logsumexp(in, out, stream()); break; case float64: logsumexp(in, out, stream()); break; default: throw std::runtime_error( "[logsumexp] only supports floating point types"); break; } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/luf.cpp ================================================ // Copyright © 2024 Apple Inc. 
#include #include "mlx/allocator.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" namespace mlx::core { template void luf_impl( const array& a, array& lu, array& pivots, array& row_indices, Stream stream) { int M = a.shape(-2); int N = a.shape(-1); int K = std::min(M, N); // Copy a into lu and make it col contiguous auto ndim = lu.ndim(); auto flags = lu.flags(); flags.col_contiguous = ndim == 2; flags.row_contiguous = false; flags.contiguous = true; auto strides = lu.strides(); strides[ndim - 1] = M; strides[ndim - 2] = 1; lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags); copy_cpu_inplace( a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral, stream); auto a_ptr = lu.data(); pivots.set_data(allocator::malloc(pivots.nbytes())); row_indices.set_data(allocator::malloc(row_indices.nbytes())); auto pivots_ptr = pivots.data(); auto row_indices_ptr = row_indices.data(); size_t num_matrices = a.size() / (M * N); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(a); encoder.set_output_array(lu); encoder.set_output_array(pivots); encoder.set_output_array(row_indices); encoder.dispatch( [a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K]() mutable { int info; for (size_t i = 0; i < num_matrices; ++i) { // Compute LU factorization of A getrf( /* m */ &M, /* n */ &N, /* a */ a_ptr, /* lda */ &M, /* ipiv */ reinterpret_cast(pivots_ptr), /* info */ &info); if (info != 0) { std::stringstream ss; ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info << ((info > 0) ? 
" because matrix is singular" : " because argument had an illegal value"); throw std::runtime_error(ss.str()); } // Subtract 1 to get 0-based index int j = 0; for (; j < K; ++j) { pivots_ptr[j]--; row_indices_ptr[j] = j; } for (; j < M; ++j) { row_indices_ptr[j] = j; } for (int j = K - 1; j >= 0; --j) { auto piv = pivots_ptr[j]; auto t1 = row_indices_ptr[piv]; auto t2 = row_indices_ptr[j]; row_indices_ptr[j] = t1; row_indices_ptr[piv] = t2; } // Advance pointers to the next matrix a_ptr += M * N; pivots_ptr += K; row_indices_ptr += M; } }); } void LUF::eval_cpu( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() == 1); switch (inputs[0].dtype()) { case float32: luf_impl(inputs[0], outputs[0], outputs[1], outputs[2], stream()); break; case float64: luf_impl(inputs[0], outputs[0], outputs[1], outputs[2], stream()); break; default: throw std::runtime_error( "[LUF::eval_cpu] only supports float32 or float64."); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/make_compiled_preamble.ps1 ================================================ # This script generates a C++ function that provides the CPU # code for use with kernel generation. # # Copyright © 2024 Apple Inc. $OUTPUT_FILE = $args[0] $CL = $args[1] $SRCDIR = $args[2] # Get command result as array. $CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" # Remove empty lines. # Otherwise there will be too much empty lines making the result unreadable. $CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' } # Concatenate to string. $CONTENT = $CONTENT -join "`n" # Append extra content. $CONTENT = @" $($CONTENT) using namespace mlx::core; using namespace mlx::core::detail; "@ # Convert each char to ASCII code. # Unlike the unix script that outputs string literal directly, the output from # MSVC is way too large to be embedded as string and compilation will fail, so # we store it as static array instead. 
$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0' $OUTPUT = @" const char* get_kernel_preamble() { static char preamble[] = { $CHARCODES }; return preamble; } "@ Set-Content -Path $OUTPUT_FILE -Value $OUTPUT ================================================ FILE: mlx/backend/cpu/make_compiled_preamble.sh ================================================ #!/bin/bash # # This script generates a C++ function that provides the CPU # code for use with kernel generation. # # Copyright © 2023-24 Apple Inc. OUTPUT_FILE=$1 GCC=$2 SRCDIR=$3 CLANG=$4 ARCH=$5 if [ "$CLANG" = "TRUE" ]; then read -r -d '' INCLUDES <<- EOM #include #include #include #include #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC #include #endif EOM CC_FLAGS="-arch ${ARCH} -nobuiltininc -nostdinc" else CC_FLAGS="-std=c++17" fi CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E -P "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null) cat << EOF > "$OUTPUT_FILE" const char* get_kernel_preamble() { return R"preamble( $INCLUDES $CONTENT using namespace mlx::core; using namespace mlx::core::detail; )preamble"; } EOF ================================================ FILE: mlx/backend/cpu/masked_mm.cpp ================================================ // Copyright © 2024 Apple Inc. 
#include #include "mlx/array.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/gemm.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" namespace mlx::core { namespace { template inline void mask_matrix( T* data, const mask_t* mask, int block_size, const int X, const int Y, const int64_t X_data_str, const int64_t Y_data_str, const int64_t X_mask_str, const int64_t Y_mask_str, const size_t mask_offset) { int tX = (X + block_size - 1) / block_size; int tY = (Y + block_size - 1) / block_size; for (int i = 0; i < tX; i++) { for (int j = 0; j < tY; j++) { mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str]; if (do_mask != 1) { int loc_x = i * block_size; int loc_y = j * block_size; T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str; int size_x = std::min(block_size, X - loc_x); int size_y = std::min(block_size, Y - loc_y); for (int ii = 0; ii < size_x; ii++) { for (int jj = 0; jj < size_y; jj++) { if constexpr (std::is_same_v) { data_block[ii * X_data_str + jj * Y_data_str] = T(0.); } else { data_block[ii * X_data_str + jj * Y_data_str] *= do_mask; } } } } } } } template inline void segmented_mm( const T* a, const T* b, const uint32_t* segments, T* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides, size_t num_segments, const Shape& segments_shape, const Strides& segments_strides) { int ndim = a_shape.size(); Shape a_copy = a_shape; Shape b_copy = b_shape; int32_t M = a_copy[ndim - 2]; int32_t N = b_copy[ndim - 1]; for (int i = 0; i < num_segments; i++) { uint32_t k_start = segments[elem_to_loc(2 * i, segments_shape, segments_strides)]; uint32_t k_end = segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)]; if (k_end <= k_start) { std::fill_n(out + i * M * N, M * N, T(0)); continue; } a_copy[ndim - 1] = k_end - 
k_start; b_copy[ndim - 2] = k_end - k_start; matmul( a + k_start * a_strides[ndim - 1], b + k_start * b_strides[ndim - 2], out + i * M * N, a_transposed, b_transposed, lda, ldb, N, 1.0, 0.0, 1, a_copy, a_strides, b_copy, b_strides); } } } // namespace void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { if (out.dtype() != float32) { throw std::runtime_error( "[BlockMaskedMM::eval] Currently only supports float32."); } out.set_data(allocator::malloc(out.nbytes())); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; auto check_transpose = [s = stream()](const array& arr, bool do_copy, bool expand_all = false) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (!expand_all && stx == arr.shape(-1) && sty == 1) { if (do_copy) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_cpu(arr, arr_copy, CopyType::Vector, s); return std::make_tuple(false, stx, arr_copy, true); } return std::make_tuple(false, stx, arr, false); } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) { if (do_copy) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_cpu(arr, arr_copy, CopyType::Vector, s); return std::make_tuple(true, sty, arr_copy, true); } return std::make_tuple(true, sty, arr, false); } else { int64_t stx = arr.shape(-1); array arr_copy = contiguous_copy_cpu(arr, s); return std::make_tuple(false, stx, arr_copy, true); } }; bool has_op_mask = inputs.size() > 3; bool has_out_mask = inputs.size() == 3 || inputs.size() == 5; auto [a_transposed, lda, a, a_copied] = check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_); auto [b_transposed, ldb, b, b_copied] = check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_); size_t M = a.shape(-2); size_t N = b.shape(-1); size_t K = a.shape(-1); if (M == 0 || N == 0) { return; } auto& encoder = cpu::get_command_encoder(stream()); if (K == 0) { encoder.set_output_array(out); encoder.dispatch([out_ptr = out.data(), nbytes = out.nbytes()]() { 
std::memset(out_ptr, 0, nbytes); }); return; } auto mask_array = [](const void* mask, float* data, int block_size, int batch_idx, int X, int Y, size_t X_data_str, size_t Y_data_str, const Shape& mask_shape, const Strides& mask_strides, bool is_bool) { auto ndim = mask_shape.size(); auto mask_offset = elem_to_loc( mask_shape[ndim - 1] * mask_shape[ndim - 2] * batch_idx, mask_shape, mask_strides); auto X_mask_str = mask_strides[ndim - 2]; auto Y_mask_str = mask_strides[ndim - 1]; if (is_bool) { return mask_matrix( data, static_cast(mask), block_size, X, Y, X_data_str, Y_data_str, X_mask_str, Y_mask_str, mask_offset); } else { return mask_matrix( data, static_cast(mask), block_size, X, Y, X_data_str, Y_data_str, X_mask_str, Y_mask_str, mask_offset); } }; encoder.set_input_array(a); encoder.set_input_array(b); const void* a_mask_ptr = nullptr; const void* b_mask_ptr = nullptr; const void* out_mask_ptr = nullptr; Shape a_mask_shape; Shape b_mask_shape; Shape out_mask_shape; Strides a_mask_strides; Strides b_mask_strides; Strides out_mask_strides; bool a_mask_bool = false; bool b_mask_bool = false; bool out_mask_bool = false; if (has_op_mask) { auto& a_mask = inputs[inputs.size() - 2]; auto& b_mask = inputs[inputs.size() - 1]; a_mask_ptr = a_mask.data(); b_mask_ptr = b_mask.data(); a_mask_shape = a_mask.shape(); b_mask_shape = b_mask.shape(); a_mask_strides = a_mask.strides(); b_mask_strides = b_mask.strides(); a_mask_bool = (a_mask.dtype() == bool_); b_mask_bool = (b_mask.dtype() == bool_); encoder.set_input_array(a_mask); encoder.set_input_array(b_mask); } if (has_out_mask) { auto& out_mask = inputs[2]; out_mask_ptr = out_mask.data(); out_mask_bool = (out_mask.dtype() == bool_); encoder.set_input_array(out_mask); out_mask_shape = out_mask.shape(); out_mask_strides = out_mask.strides(); } encoder.set_output_array(out); auto a_ptr = a.data(); auto b_ptr = b.data(); auto out_ptr = out.data(); size_t num_matrices = out.size() / (M * size_t(N)); auto ldc = out.shape(-1); 
encoder.dispatch([a_ptr, b_ptr, out_ptr, a_mask_ptr, b_mask_ptr, out_mask_ptr, has_op_mask, has_out_mask, block_size = block_size_, num_matrices, M, N, K, a_transposed = a_transposed, b_transposed = b_transposed, lda = lda, ldb = ldb, ldc, a_shape = a.shape(), a_strides = a.strides(), b_shape = b.shape(), b_strides = b.strides(), a_mask_shape = std::move(a_mask_shape), b_mask_shape = std::move(b_mask_shape), out_mask_shape = std::move(out_mask_shape), a_mask_strides = std::move(a_mask_strides), b_mask_strides = std::move(b_mask_strides), out_mask_strides = std::move(out_mask_strides), mask_array, a_mask_bool, b_mask_bool, out_mask_bool]() { for (int i = 0; i < num_matrices; ++i) { // Adjust pointer float* ai = a_ptr + elem_to_loc(M * K * i, a_shape, a_strides); float* bi = b_ptr + elem_to_loc(K * N * i, b_shape, b_strides); float* ci = out_ptr + M * N * i; // Zero out blocks in a and b if needed if (has_op_mask) { mask_array( a_mask_ptr, ai, block_size, i, M, K, a_transposed ? 1 : lda, a_transposed ? lda : 1, a_mask_shape, a_mask_strides, a_mask_bool); mask_array( b_mask_ptr, bi, block_size, i, K, N, b_transposed ? 1 : ldb, b_transposed ? ldb : 1, b_mask_shape, b_mask_strides, b_mask_bool); } // Do matmul cblas_sgemm( CblasRowMajor, a_transposed ? CblasTrans : CblasNoTrans, // transA b_transposed ? 
CblasTrans : CblasNoTrans, // transB M, N, K, 1.0, // alpha ai, lda, bi, ldb, 0.0, // beta ci, ldc); // Zero out blocks in out if (has_out_mask) { mask_array( out_mask_ptr, ci, block_size, i, M, N, N, 1, out_mask_shape, out_mask_strides, out_mask_bool); } } }); if (a_copied) { encoder.add_temporary(a); } if (b_copied) { encoder.add_temporary(b); } } void GatherMM::eval_cpu(const std::vector& inputs, array& out) { if (out.dtype() != float32) { throw std::runtime_error( "[GatherMM::eval] Currently only supports float32."); } out.set_data(allocator::malloc(out.nbytes())); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; std::vector temps; auto check_transpose = [s = stream(), &temps](const array& arr) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (stx == arr.shape(-1) && sty == 1) { return std::make_tuple(false, stx, arr); } else if (stx == 1 && sty == arr.shape(-2)) { return std::make_tuple(true, sty, arr); } else { temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); copy_cpu(arr, temps.back(), CopyType::General, s); int64_t stx = arr.shape(-1); return std::make_tuple(false, stx, temps.back()); } }; auto [a_transposed, lda, a] = check_transpose(a_pre); auto [b_transposed, ldb, b] = check_transpose(b_pre); size_t M = a.shape(-2); size_t N = b.shape(-1); size_t K = a.shape(-1); if (M == 0 || N == 0) { return; } auto& encoder = cpu::get_command_encoder(stream()); if (K == 0) { encoder.set_output_array(out); encoder.dispatch([out_ptr = out.data(), size = out.size()]() { std::fill_n(out_ptr, size, 0); }); return; } // Get batch dims auto batch_size_out = out.size() / (M * N); size_t matrix_stride_out = M * N; auto get_batch_dims = [](const auto& v) { return decltype(v){v.begin(), v.end() - 2}; }; auto& lhs_indices = inputs[2]; auto& rhs_indices = inputs[3]; auto batch_shape = get_batch_dims(out.shape()); auto batch_shape_A = get_batch_dims(a.shape()); auto batch_strides_A = get_batch_dims(a.strides()); auto 
// NOTE(review): continuation of GatherMM::eval_cpu; the `auto` introducing
// batch_shape_B ends the preceding line.
      batch_shape_B = get_batch_dims(b.shape());
  auto batch_strides_B = get_batch_dims(b.strides());

  const uint32_t* lhs_indices_ptr = lhs_indices.data();
  const uint32_t* rhs_indices_ptr = rhs_indices.data();

  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(lhs_indices);
  encoder.set_input_array(rhs_indices);
  encoder.set_output_array(out);

  auto ldc = out.shape(-1);

  // For each output matrix i, read the a/b batch indices for i (the index
  // arrays are addressed through their own shape/strides), then run one
  // row-major sgemm writing the dense M x N block at
  // out_ptr + matrix_stride_out * i.
  // NOTE(review): indx_A / indx_B are not range-checked here against the
  // batch extent of a / b.
  encoder.dispatch([a_ptr = a.data(), b_ptr = b.data(), out_ptr = out.data(),
                    M, N, K, lda = lda, ldb = ldb,
                    a_transposed = a_transposed, b_transposed = b_transposed,
                    ldc, lhs_indices_ptr, rhs_indices_ptr,
                    lhs_indices_shape = lhs_indices.shape(),
                    lhs_indices_strides = lhs_indices.strides(),
                    rhs_indices_shape = rhs_indices.shape(),
                    rhs_indices_strides = rhs_indices.strides(),
                    batch_size_out, matrix_stride_out,
                    batch_shape_A = std::move(batch_shape_A),
                    batch_shape_B = std::move(batch_shape_B),
                    batch_strides_A = std::move(batch_strides_A),
                    batch_strides_B = std::move(batch_strides_B)]() {
    for (int i = 0; i < batch_size_out; i++) {
      // Get index
      uint32_t indx_A = lhs_indices_ptr[elem_to_loc(
          i, lhs_indices_shape, lhs_indices_strides)];
      uint32_t indx_B = rhs_indices_ptr[elem_to_loc(
          i, rhs_indices_shape, rhs_indices_strides)];

      cblas_sgemm(
          CblasRowMajor,
          a_transposed ? CblasTrans : CblasNoTrans, // transA
          b_transposed ?
// NOTE(review): finishes the cblas_sgemm call and closure of
// GatherMM::eval_cpu begun on the preceding line.
              CblasTrans : CblasNoTrans, // transB
          M,
          N,
          K,
          1.0f, // alpha
          a_ptr + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
          lda,
          b_ptr + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
          ldb,
          0.0f, // beta
          out_ptr + matrix_stride_out * i,
          ldc);
    }
  });
  encoder.add_temporaries(std::move(temps));
}

// Segmented matmul over a = inputs[0], b = inputs[1] and
// segments = inputs[2]; the per-dtype work is delegated to segmented_mm.
// float64/float32/float16/bfloat16 are handled; any other dtype throws
// std::invalid_argument when the dispatched closure runs.
void SegmentedMM::eval_cpu(const std::vector& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));

  auto& s = stream();
  auto& encoder = cpu::get_command_encoder(stream());
  // Classify the trailing 2-D layout of x (row-major or column-major); if it
  // is neither, copy x into a new array registered directly as an encoder
  // temporary. Returns (transposed, leading dimension, array).
  auto check_transpose = [&s, &encoder](const array& x) {
    auto stx = x.strides()[x.ndim() - 2];
    auto sty = x.strides()[x.ndim() - 1];
    if (stx == x.shape(-1) && sty == 1) {
      return std::make_tuple(false, stx, x);
    } else if (stx == 1 && sty == x.shape(-2)) {
      return std::make_tuple(true, sty, x);
    } else {
      array xc(x.shape(), x.dtype(), nullptr, {});
      copy_cpu(x, xc, CopyType::General, s);
      encoder.add_temporary(xc);
      int64_t stx = x.shape(-1);
      return std::make_tuple(false, stx, xc);
    }
  };

  auto [a_transposed, lda, a] = check_transpose(inputs[0]);
  auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
  auto& segments = inputs[2];

  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(segments);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    b = array::unsafe_weak_copy(b),
                    segments = array::unsafe_weak_copy(segments),
                    out_ptr = out.data(),
                    a_transposed = a_transposed,
                    b_transposed = b_transposed,
                    lda = lda,
                    ldb = ldb]() {
    // segments.size() / 2 is passed as the segment count -- presumably each
    // segment is stored as a pair of values; confirm against segmented_mm.
    switch (a.dtype()) {
      case float64:
        segmented_mm(
            a.data(), b.data(), segments.data(), static_cast(out_ptr),
            a_transposed, b_transposed, lda, ldb,
            a.shape(), a.strides(), b.shape(), b.strides(),
            segments.size() / 2, segments.shape(), segments.strides());
        break;
      case float32:
        segmented_mm(
            a.data(), b.data(), segments.data(), static_cast(out_ptr),
            a_transposed, b_transposed, lda, ldb,
            a.shape(), a.strides(), b.shape(), b.strides(),
            segments.size() / 2, segments.shape(), segments.strides());
        break;
      case float16:
        segmented_mm(
// NOTE(review): continues the float16 case of SegmentedMM::eval_cpu from the
// preceding line.
            a.data(), b.data(), segments.data(), static_cast(out_ptr),
            a_transposed, b_transposed, lda, ldb,
            a.shape(), a.strides(), b.shape(), b.strides(),
            segments.size() / 2, segments.shape(), segments.strides());
        break;
      case bfloat16:
        segmented_mm(
            a.data(), b.data(), segments.data(), static_cast(out_ptr),
            a_transposed, b_transposed, lda, ldb,
            a.shape(), a.strides(), b.shape(), b.strides(),
            segments.size() / 2, segments.shape(), segments.strides());
        break;
      default:
        throw std::invalid_argument(
            "Segmented mm supports only real float types.");
    }
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cpu/matmul.cpp
================================================
// Copyright © 2023-2024 Apple Inc.

#include
#include "mlx/array.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/primitives.h"

namespace mlx::core {

// Enqueue a batched GEMM for element type T on `stream`, delegating to
// matmul (gemm.h) with alpha/beta -- presumably the usual
// out = alpha * op(a) @ op(b) + beta * out convention, where op()
// transposes an operand whose *_transposed flag is set; confirm in gemm.h.
// The batch count is a.size() divided by the size of one trailing matrix
// of a.
template
void matmul_dispatch(
    const array& a,
    const array& b,
    array& out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    float alpha,
    float beta,
    Stream stream) {
  const T* a_ptr = a.data();
  const T* b_ptr = b.data();
  T* out_ptr = out.data();
  size_t ldc = out.shape(-1);
  size_t batch_size = a.size() / (a.shape(-2) * a.shape(-1));
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  encoder.dispatch([a_ptr, b_ptr, out_ptr, a_transposed, b_transposed, lda,
                    ldb, ldc, alpha, beta, batch_size,
                    a_shape = a.shape(), a_strides = a.strides(),
                    b_shape = b.shape(), b_strides = b.strides()]() {
    matmul(
        a_ptr, b_ptr, out_ptr, a_transposed, b_transposed, lda, ldb, ldc,
        alpha, beta, batch_size, a_shape, a_strides, b_shape, b_strides);
  });
}

// Matmul driver used by Matmul::eval_cpu and AddMM::eval_cpu. It normalizes
// both operands to a row- or column-major trailing layout (copying when
// neither applies), then dispatches on out.dtype(): float32, float16,
// bfloat16, float64 or complex64; any other dtype throws
// std::runtime_error. No GEMM is enqueued when M or N is 0.
void matmul_general(
    const array& a_pre,
    const array& b_pre,
    array& out,
    Stream stream,
    float alpha = 1.0f,
    float beta = 0.0f) {
  std::vector temps;
  auto
// NOTE(review): continues matmul_general; the `auto` declaring
// check_transpose ends the preceding line.
      check_transpose = [stream, &temps](const array& arr) {
    // Row-major -> (false, row stride, arr); column-major -> (true, column
    // stride, arr); otherwise copy into a new array kept in temps and report
    // it as row-major with leading dimension shape(-1).
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (stx == arr.shape(-1) && sty == 1) {
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && sty == arr.shape(-2)) {
      return std::make_tuple(true, sty, arr);
    } else {
      temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
      copy_cpu(arr, temps.back(), CopyType::General, stream);
      stx = arr.shape(-1);
      return std::make_tuple(false, stx, temps.back());
    }
  };

  auto [a_transposed, lda, a] = check_transpose(a_pre);
  auto [b_transposed, ldb, b] = check_transpose(b_pre);
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  if (M == 0 || N == 0) {
    return;
  }

  if (out.dtype() == float32) {
    matmul_dispatch(
        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
  } else if (out.dtype() == float16) {
    matmul_dispatch(
        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
  } else if (out.dtype() == bfloat16) {
    matmul_dispatch(
        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
  } else if (out.dtype() == float64) {
    matmul_dispatch(
        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
  } else if (out.dtype() == complex64) {
    matmul_dispatch(
        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
  } else {
    throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
  }
  // Register any layout copies made by check_transpose with the encoder.
  cpu::get_command_encoder(stream).add_temporaries(std::move(temps));
}

// out = inputs[0] @ inputs[1]. An empty contraction dimension (K == 0)
// yields an all-zero output, written by a dispatched memset.
void Matmul::eval_cpu(const std::vector& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));
  if (inputs[0].shape(-1) == 0) {
    auto& encoder = cpu::get_command_encoder(stream());
    encoder.set_output_array(out);
    encoder.dispatch([out_ptr = out.data(), nbytes = out.nbytes()]() {
      std::memset(out_ptr, 0, nbytes);
    });
    return;
  }
  matmul_general(inputs[0], inputs[1], out, stream());
}

// Presumably out = alpha_ * (a @ b) + beta_ * c with inputs = {a, b, c}:
// out is first filled from c and matmul_general then runs with alpha_ and
// beta_.
void AddMM::eval_cpu(const std::vector& inputs, array& out) {
  if (out.size() == 0) {
    out.set_data(allocator::malloc(out.nbytes()));
    return;
// NOTE(review): continues AddMM::eval_cpu; the brace below closes the
// out.size() == 0 branch opened on the preceding line.
  }

  // Handle empty matrix case (K=0): the product term vanishes, so out is
  // beta_ * c (a plain copy of c when beta_ == 1).
  if (inputs[0].shape(-1) == 0) {
    auto& c = inputs[2];
    if (beta_ == 1.0f) {
      CopyType ctype = c.data_size() == 1
          ? CopyType::Scalar
          : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
      copy_cpu(c, out, ctype, stream());
    } else {
      array beta_scalar = array(beta_, c.dtype());
      auto& encoder = cpu::get_command_encoder(stream());
      binary_float_op_cpu(c, beta_scalar, out, detail::Multiply(), stream());
      encoder.add_temporary(std::move(beta_scalar));
    }
    return;
  }

  // Fill output with C; matmul_general then runs with beta_ on top of it.
  auto& c = inputs[2];
  CopyType ctype = c.data_size() == 1
      ? CopyType::Scalar
      : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
  copy_cpu(c, out, ctype, stream());
  matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cpu/primitives.cpp
================================================
// Copyright © 2023-2024 Apple Inc.

#include
#include
#include
#include
#include

#include "mlx/allocator.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/arange.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/threefry.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

// Reshape in into out. If prepare_reshape reports that a copy is
// necessary, out gets its own buffer and a general copy; otherwise out
// shares in's buffer with the strides prepare_reshape computed.
void reshape(const array& in, array& out) {
  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
  if (copy_necessary) {
    out.set_data(allocator::malloc(out.nbytes()));
    copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream());
  } else {
    shared_buffer_reshape(in, out_strides, out);
  }
}

// Compute on `stream` the element offset sum_i indices[i] * strides[axes[i]]
// into a one-element int64 array. The indices buffer is reused (donated)
// when it is donatable and at least as large as one int64.
// Returns {offset array, whether the buffer was donated}.
static std::pair compute_dynamic_offset(
    const array& indices,
    const Strides& strides,
    const std::vector& axes,
    Stream stream) {
  array offset({1}, int64, nullptr, {});
  bool donate = indices.is_donatable() &&
      (indices.data_size() * indices.itemsize()) >= offset.itemsize();
  if (donate) {
    offset.copy_shared_buffer(indices);
  } else {
// NOTE(review): continues the non-donating branch of compute_dynamic_offset
// from the preceding line.
    offset.set_data(allocator::malloc(offset.itemsize()));
  }

  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(indices);
  encoder.set_output_array(offset);

  // The sum is accumulated in a local and stored only after every index has
  // been read, so this stays correct when offset shares the indices buffer.
  // NOTE(review): indices[i] is read without strides -- assumes indices is
  // contiguous with one entry per element of axes; confirm at call sites.
  auto compute_offset =
      [strides, axes, offset = offset.data()](const auto* indices) {
        int64_t offset_ = 0;
        for (int i = 0; i < axes.size(); ++i) {
          offset_ += indices[i] * strides[axes[i]];
        }
        offset[0] = offset_;
      };
  switch (indices.dtype()) {
    case int8:
    case uint8:
      encoder.dispatch(compute_offset, indices.data());
      break;
    case int16:
    case uint16:
      encoder.dispatch(compute_offset, indices.data());
      break;
    case int32:
    case uint32:
      encoder.dispatch(compute_offset, indices.data());
      break;
    case int64:
    case uint64:
      encoder.dispatch(compute_offset, indices.data());
      break;
    default:
      throw std::runtime_error("Invalid indices type.");
  }
  return {offset, donate};
}

// The primitives below need no CPU-specific work: they forward to the
// shared eval() (or, for Slice, to slice()).
void AsStrided::eval_cpu(const std::vector& inputs, array& out) {
  eval(inputs, out);
}
void Broadcast::eval_cpu(const std::vector& inputs, array& out) {
  eval(inputs, out);
}
void BroadcastAxes::eval_cpu(const std::vector& inputs, array& out) {
  eval(inputs, out);
}
void Copy::eval_cpu(const std::vector& inputs, array& out) {
  eval(inputs, out);
}
void CustomTransforms::eval_cpu(
    const std::vector& inputs,
    std::vector& outputs) {
  eval(inputs, outputs);
}
void Depends::eval_cpu(
    const std::vector& inputs,
    std::vector& outputs) {
  eval(inputs, outputs);
}
void ExpandDims::eval_cpu(const std::vector& inputs, array& out) {
  eval(inputs, out);
}
void NumberOfElements::eval_cpu(const std::vector& inputs, array& out) {
  eval(inputs, out);
}
void Slice::eval_cpu(const std::vector& inputs, array& out) {
  slice(inputs[0], out, start_indices_, strides_);
}
void Split::eval_cpu(
    const std::vector& inputs,
    std::vector& outputs) {
  eval(inputs, outputs);
}
void Squeeze::eval_cpu(const std::vector& inputs, array& out) {
  eval(inputs, out);
}
void StopGradient::eval_cpu(const std::vector& inputs, array& out) {
  eval(inputs, out);
}
void Transpose::eval_cpu(const std::vector&
// NOTE(review): completes the Transpose::eval_cpu signature begun on the
// preceding line.
    inputs, array& out) {
  eval(inputs, out);
}

// Fill out via the arange helper, which receives the first two values
// (start_ and start_ + step_) and the element count -- presumably an
// arithmetic sequence; confirm in arange.h. bool_ outputs are rejected.
void Arange::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc(out.nbytes()));
  switch (out.dtype()) {
    case bool_:
      throw std::runtime_error("Bool type unsupported for arange.");
      break;
    case uint8:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case uint16:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case uint32:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case uint64:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case int8:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case int16:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case int32:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case int64:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case float16:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case float32:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case float64:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case bfloat16:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
    case complex64:
      arange(start_, start_ + step_, out, out.size(), stream());
      break;
  }
}

// AsType is implemented as copy_cpu from in into out (whose dtype is the
// target type), using a vector copy when in is contiguous.
void AsType::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  CopyType ctype = in.flags().contiguous ?
// NOTE(review): completes the CopyType selection of AsType::eval_cpu begun
// on the preceding line.
      CopyType::Vector : CopyType::General;
  copy_cpu(in, out, ctype, stream());
}

// Concatenate inputs along axis_. Each input is copied into a strided view
// of out that starts at the running (prefix-summed) extent along axis_.
void Concatenate::eval_cpu(const std::vector& inputs, array& out) {
  std::vector sizes;
  sizes.push_back(0);
  for (auto& p : inputs) {
    sizes.push_back(p.shape(axis_));
  }
  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());

  out.set_data(allocator::malloc(out.nbytes()));

  auto strides = out.strides();
  auto flags = out.flags();
  // Each view is a sub-block of out, so it is marked non-contiguous.
  flags.row_contiguous = false;
  flags.col_contiguous = false;
  flags.contiguous = false;
  for (int i = 0; i < inputs.size(); i++) {
    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
    size_t data_offset = strides[axis_] * sizes[i];
    out_slice.copy_shared_buffer(
        out, strides, flags, out_slice.size(), data_offset);
    copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
  }
}

// Make in contiguous. If in is row-contiguous (or col-contiguous when
// allow_col_major_) and its buffer exceeds out.nbytes() by at most
// extra_bytes, out simply shares in's buffer; otherwise a general copy is
// made.
void Contiguous::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  constexpr size_t extra_bytes = 16384;
  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
      (in.flags().row_contiguous ||
       (allow_col_major_ && in.flags().col_contiguous))) {
    out.copy_shared_buffer(in);
  } else {
    copy_cpu(in, out, CopyType::General, stream());
  }
}

void Flatten::eval_cpu(const std::vector& inputs, array& out) {
  reshape(inputs[0], out);
}

void Unflatten::eval_cpu(const std::vector& inputs, array& out) {
  reshape(inputs[0], out);
}

// Fill out from in, choosing a scalar, vector or general copy by layout.
void Full::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  assert(in.dtype() == out.dtype());
  CopyType ctype;
  if (in.data_size() == 1) {
    ctype = CopyType::Scalar;
  } else if (in.flags().contiguous) {
    ctype = CopyType::Vector;
  } else {
    ctype = CopyType::General;
  }
  copy_cpu(in, out, ctype, stream());
}

// Pad in with the scalar val: fill out with val, then copy in into the view
// of out offset by the low padding along axes_.
void Pad::eval_cpu(const std::vector& inputs, array& out) {
  // Inputs must be base input array and scalar val array
  assert(inputs.size() == 2);
  auto& in = inputs[0];
  auto& val = inputs[1];

  // Padding value must be a scalar
  assert(val.size() == 1);

  // Padding value, input
// ... and output must be of the same type
  assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());

  // Fill output with val
  copy_cpu(val, out, CopyType::Scalar, stream());

  // Find offset for start of input values (negative axes count from the
  // end)
  size_t data_offset = 0;
  for (int i = 0; i < axes_.size(); i++) {
    auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
    data_offset += out.strides()[ax] * low_pad_size_[i];
  }

  // Extract slice from output where input will be pasted
  array out_slice(in.shape(), out.dtype(), nullptr, {});
  out_slice.copy_shared_buffer(
      out, out.strides(), out.flags(), out_slice.size(), data_offset);

  // Copy input values into the slice
  copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}

// Generate random bits with random::threefry2x32_hash: each key (two
// consecutive elements of keys) fills bytes_per_key bytes of out.
void RandomBits::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  // keys has shape (N1, ..., NK, 2)
  // out has shape (N1, ..., NK, M1, M2, ...)
  auto& keys = inputs[0];
  size_t num_keys = keys.size() / 2;

  // NOTE(review): num_keys is used as a divisor; confirm an empty keys
  // array can never reach this point.
  size_t elems_per_key = out.size() / num_keys;
  size_t bytes_per_key = out.itemsize() * elems_per_key;
  out.set_data(allocator::malloc(out.nbytes()));

  auto kptr = inputs[0].data();
  auto cptr = out.data();
  auto& encoder = cpu::get_command_encoder(stream());
  encoder.set_input_array(inputs[0]);
  encoder.set_output_array(out);
  encoder.dispatch([kptr, cptr, bytes_per_key, num_keys,
                    kshape = keys.shape(),
                    kstrides = keys.strides()]() mutable {
    // Store the 32-bit word v at word index loc of the current key's output
    // bytes, truncating it when it would run past bytes_per_key.
    auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
      if (4 * loc + 4 <= bytes_per_key) {
        reinterpret_cast(cptr)[loc] = v;
      } else {
        std::copy(
            reinterpret_cast(&v),
            reinterpret_cast(&v) + bytes_per_key - 4 * loc,
            cptr + 4 * loc);
      }
    };

    // out_skip = 32-bit words per key (rounded up). The hash for counter j
    // yields two words, written to word j and word half_size + !even + j;
    // with an odd word count the middle word (index half_size) is filled
    // separately.
    size_t out_skip = (bytes_per_key + 4 - 1) / 4;
    auto half_size = out_skip / 2;
    bool even = out_skip % 2 == 0;
    for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
      auto ptr = reinterpret_cast(cptr);
      // Get ith key
      auto kidx = 2 * i;
      auto k1_elem = elem_to_loc(kidx, kshape, kstrides);
      auto k2_elem = elem_to_loc(kidx + 1, kshape, kstrides);
      auto key =
// NOTE(review): completes the `auto key =` declaration in RandomBits begun
// on the preceding line.
          std::make_pair(kptr[k1_elem], kptr[k2_elem]);

      std::pair count{0, half_size + !even};
      for (; count.first + 1 < half_size; count.first++, count.second++) {
        std::tie(ptr[count.first], ptr[count.second]) =
            random::threefry2x32_hash(key, count);
      }
      // Last pair: the second word is the final word and may be truncated.
      if (count.first < half_size) {
        auto rb = random::threefry2x32_hash(key, count);
        ptr[count.first++] = rb.first;
        copy_remaining(cptr, count.second, rb.second);
      }
      // Odd word count: fill the middle word (possibly truncated).
      if (!even) {
        count.second = 0;
        copy_remaining(
            cptr, half_size, random::threefry2x32_hash(key, count).first);
      }
    }
  });
}

void Reshape::eval_cpu(const std::vector& inputs, array& out) {
  reshape(inputs[0], out);
}

// Copy the region of in that starts at a run-time offset (computed from
// inputs[1] by compute_dynamic_offset) into out.
void DynamicSlice::eval_cpu(const std::vector& inputs, array& out) {
  if (out.size() == 0) {
    out.set_data(allocator::malloc(0));
    return;
  }
  auto& in = inputs[0];
  out.set_data(allocator::malloc(out.nbytes()));
  auto [in_offset, donated] =
      compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
  copy_cpu_inplace(
      /* const array& src = */ in,
      /* array& dst = */ out,
      /* const Shape& data_shape = */ out.shape(),
      /* const Strides& i_strides = */ in.strides(),
      /* const Strides& o_strides = */ out.strides(),
      /* int64_t i_offset = */ 0,
      /* int64_t o_offset = */ 0,
      /* CopyType ctype = */ CopyType::GeneralGeneral,
      stream(),
      /* const std::optional& dynamic_i_offset = */ in_offset,
      /* const std::optional& dynamic_o_offset = */ std::nullopt);
  // A donated offset lives in the indices buffer; otherwise the freshly
  // allocated offset array is registered as an encoder temporary.
  if (!donated) {
    cpu::get_command_encoder(stream()).add_temporary(std::move(in_offset));
  }
}

// Copy in to out, then overwrite the region at a run-time offset (computed
// from inputs[2]) with the update array inputs[1].
void DynamicSliceUpdate::eval_cpu(
    const std::vector& inputs,
    array& out) {
  if (out.size() == 0) {
    out.set_data(allocator::malloc(0));
    return;
  }

  auto& in = inputs[0];
  auto& upd = inputs[1];

  // Copy or move src to dst
  auto ctype = in.flags().contiguous && in.size() == in.data_size()
      ? CopyType::Vector
      : CopyType::General;
  copy_cpu(in, out, in.data_size() == 1 ?
// NOTE(review): completes the copy_cpu call of DynamicSliceUpdate begun on
// the preceding line.
                                       CopyType::Scalar : ctype,
           stream());

  auto [out_offset, donated] =
      compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
  copy_cpu_inplace(
      /* const array& src = */ upd,
      /* array& dst = */ out,
      /* const Shape& data_shape = */ upd.shape(),
      /* const Strides& i_strides = */ upd.strides(),
      /* const Strides& o_strides = */ out.strides(),
      /* int64_t i_offset = */ 0,
      /* int64_t o_offset = */ 0,
      /* CopyType ctype = */ CopyType::GeneralGeneral,
      stream(),
      /* const std::optional& dynamic_i_offset = */ std::nullopt,
      /* const std::optional& dynamic_o_offset = */ out_offset);
  if (!donated) {
    cpu::get_command_encoder(stream()).add_temporary(std::move(out_offset));
  }
}

// Reinterpret the bytes of in as out.dtype(). When the layout allows it,
// out shares in's buffer with strides rescaled by the byte-size ratio;
// otherwise in is first copied into a row-contiguous temporary (bool is
// copied as uint8) and out shares that.
void View::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  auto ibytes = size_of(in.dtype());
  auto obytes = size_of(out.dtype());
  // Conditions under which out can share in's buffer (disjunction):
  // - type size is the same
  // - type size is smaller and the last axis is contiguous
  // - the entire array is row contiguous
  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
      in.flags().row_contiguous) {
    // Rescale every stride except the last from in-elements to out-elements.
    auto strides = in.strides();
    for (int i = 0; i < static_cast(strides.size()) - 1; ++i) {
      strides[i] *= ibytes;
      strides[i] /= obytes;
    }
    out.copy_shared_buffer(
        in, strides, in.flags(), in.data_size() * ibytes / obytes);
  } else {
    auto tmp = array(
        in.shape(), in.dtype() == bool_ ?
// NOTE(review): completes the `auto tmp = array(...)` expression of
// View::eval_cpu begun on the preceding line.
                                     uint8 : in.dtype(),
        nullptr,
        {});
    tmp.set_data(allocator::malloc(tmp.nbytes()));
    if (in.dtype() == bool_) {
      auto in_tmp = array(in.shape(), uint8, nullptr, {});
      in_tmp.copy_shared_buffer(in);
      copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream());
    } else {
      copy_cpu_inplace(in, tmp, CopyType::General, stream());
    }

    auto flags = out.flags();
    flags.contiguous = true;
    flags.row_contiguous = true;
    // Row-major data is also column-major only when the size equals the
    // largest dimension, i.e. at most one dimension exceeds 1.
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
    out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cpu/qrf.cpp
================================================
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"

namespace mlx::core {

// Batched QR factorization via geqrf/orgqr. Each M x N matrix of a is
// copied into the column-major scratch array `in`. From it, r receives the
// upper-triangular num_reflectors x N factor and q the
// M x num_reflectors factor, with num_reflectors = min(M, N).
template
void qrf_impl(const array& a, array& q, array& r, Stream stream) {
  const int M = a.shape(-2);
  const int N = a.shape(-1);
  const int lda = M;
  // NOTE(review): divides by M * N; confirm empty matrices are rejected
  // before reaching this point.
  size_t num_matrices = a.size() / (M * N);

  // Copy A to inplace input and make it col-contiguous
  array in(a.shape(), a.dtype(), nullptr, {});
  auto flags = in.flags();

  // Copy the input to be column contiguous
  flags.col_contiguous = num_matrices == 1;
  flags.row_contiguous = false;
  auto strides = in.strides();
  strides[in.ndim() - 2] = 1;
  strides[in.ndim() - 1] = M;
  in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
  copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream);
  auto& encoder = cpu::get_command_encoder(stream);
  q.set_data(allocator::malloc(q.nbytes()));
  r.set_data(allocator::malloc(r.nbytes()));

  auto in_ptr = in.data();
  auto r_ptr = r.data();
  auto q_ptr = q.data();

  encoder.set_input_array(in);
  encoder.set_output_array(q);
  encoder.set_output_array(r);
  encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
    int
// NOTE(review): completes `int num_reflectors` inside qrf_impl's dispatched
// closure begun on the preceding line.
        num_reflectors = std::min(M, N);
    auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);

    T optimal_work;
    int lwork = -1;
    int info;

    // Compute workspace size (lwork == -1 is a workspace-size query by
    // LAPACK convention; the answer comes back in optimal_work).
    geqrf(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);

    // Update workspace size
    lwork = optimal_work;
    auto work = allocator::malloc(sizeof(T) * lwork);

    // NOTE(review): `info` is never inspected after any LAPACK call here, so
    // factorization failures go unreported.
    // Loop over matrices
    for (int i = 0; i < num_matrices; ++i) {
      // Solve
      geqrf(
          &M,
          &N,
          in_ptr + M * N * i,
          &lda,
          static_cast(tau.raw_ptr()) + num_reflectors * i,
          static_cast(work.raw_ptr()),
          &lwork,
          &info);
    }
    allocator::free(work);

    // Extract R (row-major) from the upper triangle of each column-major
    // factored matrix, zeroing entries below the diagonal.
    for (int i = 0; i < num_matrices; ++i) {
      /// num_reflectors x N
      for (int j = 0; j < num_reflectors; ++j) {
        for (int k = 0; k < j; ++k) {
          r_ptr[i * N * num_reflectors + j * N + k] = 0;
        }
        for (int k = j; k < N; ++k) {
          r_ptr[i * N * num_reflectors + j * N + k] =
              in_ptr[i * N * M + j + k * M];
        }
      }
    }

    // Get work size
    lwork = -1;
    orgqr(
        &M,
        &num_reflectors,
        &num_reflectors,
        nullptr,
        &lda,
        nullptr,
        &optimal_work,
        &lwork,
        &info);
    lwork = optimal_work;
    work = allocator::malloc(sizeof(T) * lwork);

    // Loop over matrices
    for (int i = 0; i < num_matrices; ++i) {
      // Compute Q (written over the first num_reflectors columns of in)
      orgqr(
          &M,
          &num_reflectors,
          &num_reflectors,
          in_ptr + M * N * i,
          &lda,
          static_cast(tau.raw_ptr()) + num_reflectors * i,
          static_cast(work.raw_ptr()),
          &lwork,
          &info);
    }

    // Copy Q from the column-major scratch into row-major q.
    for (int i = 0; i < num_matrices; ++i) {
      // M x num_reflectors
      for (int j = 0; j < M; ++j) {
        for (int k = 0; k < num_reflectors; ++k) {
          q_ptr[i * M * num_reflectors + j * num_reflectors + k] =
              in_ptr[i * N * M + j + k * M];
        }
      }
    }

    // Cleanup
    allocator::free(work);
    allocator::free(tau);
  });
  encoder.add_temporary(in);
}

// QR factorization entry point; outputs = {q, r}. Only float32 and float64
// inputs are supported; other dtypes throw std::runtime_error.
void QRF::eval_cpu(
    const std::vector& inputs,
    std::vector& outputs) {
  switch (inputs[0].dtype()) {
    case float32:
      qrf_impl(inputs[0], outputs[0], outputs[1], stream());
      break;
    case float64:
      qrf_impl(inputs[0], outputs[0], outputs[1], stream());
      break;
    default:
      throw std::runtime_error(
          "[QRF::eval_cpu] only supports float32 or float64.");
  }
}

} // namespace mlx::core
================================================ FILE: mlx/backend/cpu/quantized.cpp ================================================ // Copyright © 2023 Apple Inc. #include "mlx/backend/common/quantized.h" #include "mlx/backend/common/unary.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/unary.h" #include "mlx/backend/cpu/unary_ops.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { namespace { array ensure_row_contiguous( const array& arr, cpu::CommandEncoder& encoder, Stream s) { if (arr.flags().row_contiguous) { return arr; } else { auto arr_cpy = contiguous_copy_cpu(arr, s); encoder.add_temporary(arr_cpy); return arr_cpy; } }; const static float FP4_LUT[16] = { +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f, -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f}; template static inline T dequantize_scale(uint8_t s) { if constexpr (group_size == 16) { return static_cast(detail::FromFP8{}(s)); } else { using FOrI = union { bfloat16_t f; uint16_t i; }; FOrI out; out.i = (s == 0 ? 
// NOTE(review): completes the bfloat16 bit-pattern assignment of
// dequantize_scale begun on the preceding line.
                 0x40 : (static_cast(s) << 7));
    return static_cast(out.f);
  }
}

// Unpack one pack of bits-wide unsigned values from the bytes w_in into
// w_out, least-significant bits first:
// - bits == 3: 8 values from 3 bytes
// - bits == 5: 8 values from 5 bytes
// - bits == 6: 4 values from 3 bytes
// Values that straddle a byte boundary are assembled from both bytes.
template
void extract_bits(const uint8_t* w_in, T* w_out) {
  static_assert(bits == 3 || bits == 5 || bits == 6);
  if (bits == 3) {
    w_out[0] = static_cast(w_in[0] & 0x7);
    w_out[1] = static_cast((w_in[0] & 0x38) >> 3);
    w_out[2] = static_cast(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
    w_out[3] = static_cast((w_in[1] & 0xe) >> 1);
    w_out[4] = static_cast((w_in[1] & 0x70) >> 4);
    w_out[5] = static_cast(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
    w_out[6] = static_cast((w_in[2] & 0x1c) >> 2);
    w_out[7] = static_cast((w_in[2] & 0xe0) >> 5);
  } else if (bits == 5) {
    w_out[0] = static_cast(w_in[0] & 0x1f);
    w_out[1] = static_cast(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
    w_out[2] = static_cast((w_in[1] & 0x7c) >> 2);
    w_out[3] = static_cast(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
    w_out[4] = static_cast(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
    w_out[5] = static_cast((w_in[3] & 0x3e) >> 1);
    w_out[6] = static_cast(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
    w_out[7] = static_cast((w_in[4] & 0xf8) >> 3);
  } else if (bits == 6) {
    w_out[0] = static_cast(w_in[0] & 0x3f);
    w_out[1] =
        static_cast(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
    w_out[2] =
        static_cast(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
    w_out[3] = static_cast((w_in[2] >> 2) & 0x3f);
  }
}

// Reference quantized matmul with a non-transposed weight. For each row m
// of x: result[m, :] = sum_k x[m, k] * (scale * w[k, :] + bias). Each row
// of w is packed along N, and each (scale, bias) pair covers group_size
// consecutive outputs of one weight row.
template
void _qmm(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K) {
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int pack_factor = get_pack_factor(bits, 8);
  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
  constexpr int packs_in_group = group_size / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint8_t* w_local = (const uint8_t*)w;
    const T* scales_local = scales;
    const T* biases_local = biases;

    std::fill(result, result + N, 0);

    for (int k = 0; k < K; k++) {
      T* result_local = result;
      T xi = *x++;

      for (int n = 0; n < N; n += group_size) {
        T scale =
// NOTE(review): completes `T scale =` inside _qmm's group loop begun on the
// preceding line.
            *scales_local++;
        T bias = *biases_local++;
        for (int ng = 0; ng < packs_in_group; ng++) {
          if constexpr (bits == 3 || bits == 5 || bits == 6) {
            // These widths are unpacked a whole pack at a time.
            T wl[pack_factor];
            extract_bits(w_local, wl);
#pragma clang loop unroll(full)
            for (int p = 0; p < pack_factor; p++) {
              (*result_local++) += xi * (scale * wl[p] + bias);
            }
            w_local += bytes_per_pack;
          } else {
            // Remaining widths: whole values per byte, low bits first.
            uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
            for (int p = 0; p < pack_factor; p++) {
              (*result_local++) +=
                  xi * (scale * static_cast(wi & bitmask) + bias);
              if (bits != 8) {
                wi >>= bits;
              }
            }
          }
        }
      }
    }
    result += N;
  }
}

// Reference quantized matmul with a transposed weight:
// result[m, n] = sum_k x[m, k] * (scale * w[n, k] + bias). Each row of w is
// packed along K, and each (scale, bias) pair covers group_size
// consecutive k.
template
void _qmm_t(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K) {
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int pack_factor = get_pack_factor(bits, 8);
  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
  constexpr int packs_in_group = group_size / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint8_t* w_local = (const uint8_t*)w;
    const T* scales_local = scales;
    const T* biases_local = biases;

    for (int n = 0; n < N; n++) {
      const T* x_local = x;
      T sum = 0;
      for (int k = 0; k < K; k += group_size) {
        T scale = *scales_local++;
        T bias = *biases_local++;

        for (int kw = 0; kw < packs_in_group; kw++) {
          if constexpr (bits == 3 || bits == 5 || bits == 6) {
            T wl[pack_factor];
            extract_bits(w_local, wl);
#pragma clang loop unroll(full)
            for (int p = 0; p < pack_factor; p++) {
              sum += x_local[p] * (scale * wl[p] + bias);
            }
            w_local += bytes_per_pack;
            x_local += pack_factor;
          } else {
            uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
            for (int p = 0; p < pack_factor; p++) {
              sum +=
                  (*x_local++) * (scale * static_cast(wi & bitmask) + bias);
              if (bits != 8) {
                wi >>= bits;
              }
            }
          }
        }
      }
      *result = sum;
      result++;
    }

    x += K;
  }
}

// Unpack S bits-wide values from packed 32-bit words into one SIMD vector.
// Only (bits, S) = (4, 8) and (8, 8) are implemented; any other combination
// throws std::runtime_error.
template
simd::Simd extract_bits_simd(const uint32_t* w) {
  constexpr int bitmask = (1 << bits) - 1;
  simd::Simd wi;
  if constexpr (bits == 4 && S == 8) {
    constexpr std::array shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
    auto
// NOTE(review): completes `auto shifts(...)` in extract_bits_simd begun on
// the preceding line.
        shifts(*(simd::Simd*)&shifts_);
    // One word supplies all 8 lanes (presumably broadcast); each lane is
    // shifted to its own field and masked.
    wi = simd::Simd(*w);
    wi = wi >> shifts;
    wi = wi & bitmask;
  } else if constexpr (bits == 8 && S == 8) {
    // Two words supply the 8 lanes, four bytes each.
    constexpr std::array shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};
    auto shifts(*(simd::Simd*)&shifts_);
    auto l = simd::Simd(*w++);
    auto r = simd::Simd(*w);
    wi = simd::Simd(l, r);
    wi = wi >> shifts;
    wi = wi & bitmask;
  } else {
    // Appease compiler.. but should never get here
    throw std::runtime_error("Unsupported combination for simd qmm.");
  }
  return wi;
}

// SIMD variant of _qmm_t. For each output it consumes S values of the
// weight row per step, accumulates into a SIMD accumulator, and reduces
// with simd::sum at the end. S must be a multiple of the pack factor
// (enforced by static_assert).
template
void _qmm_t_simd(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K) {
  constexpr int pack_factor = 32 / bits;
  constexpr int packs_in_group = group_size / pack_factor;
  constexpr int S = simd::max_size;
  static_assert(
      S % pack_factor == 0, "SIMD size must be divisible by pack factor");
  constexpr int packs_per_simd = S / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint32_t* w_local = w;
    const T* scales_local = scales;
    const T* biases_local = biases;

    for (int n = 0; n < N; n++) {
      simd::Simd acc(0);
      auto x_local = x;
      for (int k = 0; k < K; k += group_size) {
        T scale = *scales_local++;
        T bias = *biases_local++;

        for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
          auto wf = simd::Simd(extract_bits_simd(w_local));
          w_local += packs_per_simd;
          wf = wf * scale;
          wf = wf + bias;
          simd::Simd x_simd = simd::load(x_local);
          acc = acc + x_simd * wf;
          x_local += S;
        }
      }

      *result = T(simd::sum(acc));
      result++;
    }
    x += K;
  }
}

// Pick the kernel for one matrix:
// - transposed weights, where the width divides 32 and the SIMD size is a
//   multiple of the values per word: the SIMD kernel;
// - other transposed weights: the scalar transposed kernel;
// - non-transposed weights: _qmm.
template
void _qmm_dispatch_transpose(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K,
    bool transposed_w) {
  if (transposed_w) {
    // the simd size must be a multiple of the number of elements per word
    if constexpr (32 % bits == 0 && simd::max_size % (32 / bits) == 0) {
      _qmm_t_simd(result, x, w, scales, biases, M, N, K);
    } else {
      _qmm_t(result, x, w, scales, biases, M, N, K);
    }
  } else {
    _qmm(result, x, w, scales, biases, M, N, K);
  }
}

// Map the runtime group_size (32, 64 or 128) to a compile-time kernel;
// other sizes throw std::invalid_argument.
template
void _qmm_dispatch_group(
    T* result,
const T* x, const uint32_t* w, const T* scales, const T* biases, int M, int N, int K, int group_size, bool transposed_w) { switch (group_size) { case 32: _qmm_dispatch_transpose( result, x, w, scales, biases, M, N, K, transposed_w); break; case 64: _qmm_dispatch_transpose( result, x, w, scales, biases, M, N, K, transposed_w); break; case 128: _qmm_dispatch_transpose( result, x, w, scales, biases, M, N, K, transposed_w); break; default: throw std::invalid_argument( "Quantization group size must be 32, 64 or 128."); } } template void _qmm_dispatch_typed( T* result, const T* x, const uint32_t* w, const T* scales, const T* biases, int M, int N, int K, int group_size, int bits, bool transposed_w) { switch (bits) { case 2: _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); break; case 3: _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); break; case 4: _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); break; case 5: _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); break; case 6: _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); break; case 8: _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); break; default: throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8."); } } template void _qmm_dispatch_typed( array& out, const array& x, const array& w, const array& scales, const array& biases, int bits, int group_size, bool transposed_w) { int K = x.shape(-1); int M = x.ndim() > 1 ? x.shape(-2) : 1; int N = out.shape(-1); int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; int g_els = w.ndim() > 2 ? 
scales.shape(-1) * scales.shape(-2) : 0; int batch_size = x.size() / (K * M); auto out_ptr = out.data(); auto x_ptr = x.data(); auto w_ptr = w.data(); auto scales_ptr = scales.data(); auto biases_ptr = biases.data(); for (int i = 0; i < batch_size; i++) { _qmm_dispatch_typed( out_ptr + i * M * N, x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()), w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()), scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()), biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()), M, N, K, bits, group_size, transposed_w); } } void _qmm_dispatch( array& out, const array& x, const array& w, const array& scales, const array& biases, int bits, int group_size, bool transposed_w) { switch (x.dtype()) { case float32: _qmm_dispatch_typed( out, x, w, scales, biases, bits, group_size, transposed_w); break; case float16: _qmm_dispatch_typed( out, x, w, scales, biases, bits, group_size, transposed_w); break; case bfloat16: _qmm_dispatch_typed( out, x, w, scales, biases, bits, group_size, transposed_w); break; default: throw std::invalid_argument( "[quantized_matmul] only floating types are supported"); } } template void fp_qmm( T* result, const T* x, const uint32_t* w, const uint8_t* scales, int M, int N, int K) { constexpr int pack_factor = get_pack_factor(bits, 8); constexpr int packs_in_group = group_size / pack_factor; for (int m = 0; m < M; m++) { const uint8_t* w_local = (const uint8_t*)w; const uint8_t* scales_local = scales; std::fill(result, result + N, 0); for (int k = 0; k < K; k++) { T* result_local = result; T xi = *x++; for (int n = 0; n < N; n += group_size) { T scale = dequantize_scale(*scales_local++); for (int ng = 0; ng < packs_in_group; ng++) { if constexpr (bits == 4) { (*result_local++) += xi * scale * static_cast(FP4_LUT[w_local[0] & 0xf]); (*result_local++) += xi * scale * static_cast(FP4_LUT[(w_local[0] >> 4) & 0xf]); } else { (*result_local++) += xi * scale * 
static_cast(detail::FromFP8{}(w_local[0])); } w_local++; } } } result += N; } } template void fp_qmm_t( T* result, const T* x, const uint32_t* w, const uint8_t* scales, int M, int N, int K) { constexpr int pack_factor = get_pack_factor(bits, 8); constexpr int packs_in_group = group_size / pack_factor; for (int m = 0; m < M; m++) { const uint8_t* w_local = (const uint8_t*)w; const uint8_t* scales_local = scales; for (int n = 0; n < N; n++) { const T* x_local = x; T sum = 0; for (int k = 0; k < K; k += group_size) { T scale = dequantize_scale(*scales_local++); T gsum = 0; for (int kw = 0; kw < packs_in_group; kw++) { if constexpr (bits == 4) { gsum += (*x_local++) * static_cast(FP4_LUT[w_local[0] & 0xf]); gsum += (*x_local++) * static_cast(FP4_LUT[(w_local[0] >> 4) & 0xf]); } else { gsum += (*x_local++) * static_cast(detail::FromFP8{}(w_local[0])); } w_local++; } sum += scale * gsum; } *result = sum; result++; } x += K; } } template simd::Simd fp_extract_bits_simd(const uint32_t* w) { if constexpr (S == 8 && bits == 4) { constexpr std::array shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}}; auto shifts(*(simd::Simd*)&shifts_); auto wi = simd::Simd(*w); wi = wi >> shifts; wi = wi & 0xf; simd::Simd w_out; for (int i = 0; i < S; ++i) { w_out[i] = FP4_LUT[wi[i]]; } return w_out; } else if constexpr (S == 8 && bits == 8) { auto w_out = simd::load(reinterpret_cast(w)); return detail::FromFP8{}(w_out); } else { // Appease compiler.. 
but should never get here throw std::runtime_error("Unsupported combination for simd qmm."); } } template void fp_qmm_t_simd( T* result, const T* x, const uint32_t* w, const uint8_t* scales, int M, int N, int K) { constexpr int pack_factor = get_pack_factor(bits, 32); constexpr int packs_in_group = group_size / pack_factor; constexpr int S = simd::max_size; static_assert( S % pack_factor == 0, "SIMD size must be divisible by pack factor"); constexpr int packs_per_simd = S / pack_factor; for (int m = 0; m < M; m++) { const uint32_t* w_local = w; const uint8_t* scales_local = scales; for (int n = 0; n < N; n++) { simd::Simd acc(0); auto x_local = x; for (int k = 0; k < K; k += group_size) { T scale = dequantize_scale(*scales_local++); simd::Simd g_acc(0); for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) { // Extract bits auto wf = fp_extract_bits_simd(w_local); w_local += packs_per_simd; simd::Simd x_simd = simd::load(x_local); g_acc = g_acc + x_simd * wf; x_local += S; } acc = acc + scale * g_acc; } *result = T(simd::sum(acc)); result++; } x += K; } } template void fp_qmm_dispatch_transpose( T* result, const T* x, const uint32_t* w, const uint8_t* scales, int M, int N, int K, bool transposed_w) { if (transposed_w) { // the simd size must be a multiple of the number of elements per word if constexpr (simd::max_size % 8 == 0) { fp_qmm_t_simd(result, x, w, scales, M, N, K); } else { fp_qmm_t(result, x, w, scales, M, N, K); } } else { fp_qmm(result, x, w, scales, M, N, K); } } template void fp_qmm_dispatch_mode( array& out, const array& x, const array& w, const array& scales, bool transposed_w) { int K = x.shape(-1); int M = x.ndim() > 1 ? x.shape(-2) : 1; int N = out.shape(-1); int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; int g_els = w.ndim() > 2 ? 
scales.shape(-1) * scales.shape(-2) : 0; int batch_size = x.size() / (K * M); auto out_ptr = out.data(); auto x_ptr = x.data(); auto w_ptr = w.data(); auto scales_ptr = scales.data(); for (int i = 0; i < batch_size; i++) { fp_qmm_dispatch_transpose( out_ptr + i * M * N, x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()), w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()), scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()), M, N, K, transposed_w); } } template void fp_qmm_dispatch_typed( array& out, const array& x, const array& w, const array& scales, int group_size, int bits, bool transposed_w) { if (bits == 8) { fp_qmm_dispatch_mode(out, x, w, scales, transposed_w); } else if (group_size == 32) { fp_qmm_dispatch_mode(out, x, w, scales, transposed_w); } else { fp_qmm_dispatch_mode(out, x, w, scales, transposed_w); } } void fp_qmm_dispatch( array& out, const array& x, const array& w, const array& scales, int group_size, int bits, bool transposed_w) { switch (x.dtype()) { case bfloat16: fp_qmm_dispatch_typed( out, x, w, scales, group_size, bits, transposed_w); break; case float16: fp_qmm_dispatch_typed( out, x, w, scales, group_size, bits, transposed_w); break; case float32: fp_qmm_dispatch_typed( out, x, w, scales, group_size, bits, transposed_w); break; default: throw std::invalid_argument( "[quantized_matmul] only floating types are supported"); } } template void _bs_qmm_dispatch_typed( array& out, const array& x, const array& w, const array& scales, const array& biases, const array& lhs_indices, const array& rhs_indices, int bits, int group_size, bool transposed_w) { int K = x.shape(-1); int M = x.shape(-2); int N = out.shape(-1); int w_els = w.shape(-1) * w.shape(-2); int g_els = scales.shape(-1) * scales.shape(-2); auto out_ptr = out.data(); auto x_ptr = x.data(); auto w_ptr = w.data(); auto scales_ptr = scales.data(); auto biases_ptr = biases.data(); auto lhs_indices_ptr = lhs_indices.data(); auto rhs_indices_ptr = 
rhs_indices.data(); for (int i = 0; i < lhs_indices.size(); i++) { int x_idx = lhs_indices_ptr[elem_to_loc( i, lhs_indices.shape(), lhs_indices.strides())]; int w_idx = rhs_indices_ptr[elem_to_loc( i, rhs_indices.shape(), rhs_indices.strides())]; _qmm_dispatch_typed( out_ptr + i * M * N, x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()), w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()), scales_ptr + elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()), biases_ptr + elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()), M, N, K, bits, group_size, transposed_w); } } void _bs_qmm_dispatch( array& out, const array& x, const array& w, const array& scales, const array& biases, const array& lhs_indices, const array& rhs_indices, int bits, int group_size, bool transposed_w) { switch (x.dtype()) { case float32: _bs_qmm_dispatch_typed( out, x, w, scales, biases, lhs_indices, rhs_indices, bits, group_size, transposed_w); break; case float16: _bs_qmm_dispatch_typed( out, x, w, scales, biases, lhs_indices, rhs_indices, bits, group_size, transposed_w); break; case bfloat16: _bs_qmm_dispatch_typed( out, x, w, scales, biases, lhs_indices, rhs_indices, bits, group_size, transposed_w); break; default: throw std::invalid_argument( "[quantized_matmul] only floating types are supported"); } } template void fp_bs_qmm_dispatch_mode( array& out, const array& x, const array& w, const array& scales, const array& lhs_indices, const array& rhs_indices, bool transposed_w) { int K = x.shape(-1); int M = x.shape(-2); int N = out.shape(-1); int w_els = w.shape(-1) * w.shape(-2); int g_els = scales.shape(-1) * scales.shape(-2); auto out_ptr = out.data(); auto x_ptr = x.data(); auto w_ptr = w.data(); auto scales_ptr = scales.data(); auto lhs_indices_ptr = lhs_indices.data(); auto rhs_indices_ptr = rhs_indices.data(); for (int i = 0; i < lhs_indices.size(); i++) { int x_idx = lhs_indices_ptr[elem_to_loc( i, lhs_indices.shape(), lhs_indices.strides())]; int w_idx = 
rhs_indices_ptr[elem_to_loc( i, rhs_indices.shape(), rhs_indices.strides())]; fp_qmm_dispatch_transpose( out_ptr + i * M * N, x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()), w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()), scales_ptr + elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()), M, N, K, transposed_w); } } template void fp_bs_qmm_dispatch_typed( array& out, const array& x, const array& w, const array& scales, const array& lhs_indices, const array& rhs_indices, int group_size, int bits, bool transposed_w) { if (bits == 8) { fp_bs_qmm_dispatch_mode( out, x, w, scales, lhs_indices, rhs_indices, transposed_w); } else if (group_size == 32) { fp_bs_qmm_dispatch_mode( out, x, w, scales, lhs_indices, rhs_indices, transposed_w); } else { fp_bs_qmm_dispatch_mode( out, x, w, scales, lhs_indices, rhs_indices, transposed_w); } } void fp_bs_qmm_dispatch( array& out, const array& x, const array& w, const array& scales, const array& lhs_indices, const array& rhs_indices, int group_size, int bits, bool transposed_w) { switch (x.dtype()) { case float32: fp_bs_qmm_dispatch_typed( out, x, w, scales, lhs_indices, rhs_indices, group_size, bits, transposed_w); break; case float16: fp_bs_qmm_dispatch_typed( out, x, w, scales, lhs_indices, rhs_indices, group_size, bits, transposed_w); break; case bfloat16: fp_bs_qmm_dispatch_typed( out, x, w, scales, lhs_indices, rhs_indices, group_size, bits, transposed_w); break; default: throw std::invalid_argument( "[quantized_matmul] only floating types are supported"); } } } // namespace void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; auto& scales_pre = inputs[2]; auto& encoder = cpu::get_command_encoder(stream()); auto x = ensure_row_contiguous(x_pre, encoder, stream()); auto w = ensure_row_contiguous(w_pre, encoder, stream()); auto scales = ensure_row_contiguous(scales_pre, encoder, stream()); 
out.set_data(allocator::malloc(out.nbytes())); encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(scales); encoder.set_output_array(out); if (mode_ == QuantizationMode::Affine) { auto biases = ensure_row_contiguous(inputs[3], encoder, stream()); encoder.set_input_array(biases); encoder.dispatch([out = array::unsafe_weak_copy(out), x = array::unsafe_weak_copy(x), w = array::unsafe_weak_copy(w), scales = array::unsafe_weak_copy(scales), biases = array::unsafe_weak_copy(biases), group_size_ = group_size_, bits_ = bits_, transpose_ = transpose_]() mutable { _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); }); } else { encoder.dispatch([out = array::unsafe_weak_copy(out), x = array::unsafe_weak_copy(x), w = array::unsafe_weak_copy(w), scales = array::unsafe_weak_copy(scales), group_size_ = group_size_, bits_ = bits_, transpose_ = transpose_]() mutable { fp_qmm_dispatch(out, x, w, scales, group_size_, bits_, transpose_); }); } } void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; auto& scales_pre = inputs[2]; auto& lhs_indices = inputs[inputs.size() - 2]; auto& rhs_indices = inputs[inputs.size() - 1]; auto& encoder = cpu::get_command_encoder(stream()); auto ensure_row_contiguous_last_dims = [s = stream(), &encoder](const array& arr) { auto stride_0 = arr.strides()[arr.ndim() - 2]; auto stride_1 = arr.strides()[arr.ndim() - 1]; if (stride_0 == arr.shape(-1) && stride_1 == 1) { return arr; } else { auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {}); copy_cpu(arr, arr_cpy, CopyType::General, s); encoder.add_temporary(arr_cpy); return arr_cpy; } }; auto x = ensure_row_contiguous_last_dims(x_pre); auto w = ensure_row_contiguous_last_dims(w_pre); auto scales = ensure_row_contiguous_last_dims(scales_pre); out.set_data(allocator::malloc(out.nbytes())); encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(scales); 
encoder.set_input_array(lhs_indices); encoder.set_input_array(rhs_indices); encoder.set_output_array(out); if (mode_ == QuantizationMode::Affine) { auto biases = ensure_row_contiguous_last_dims(inputs[3]); encoder.set_input_array(biases); encoder.dispatch([out = array::unsafe_weak_copy(out), x = array::unsafe_weak_copy(x), w = array::unsafe_weak_copy(w), scales = array::unsafe_weak_copy(scales), biases = array::unsafe_weak_copy(biases), lhs_indices = array::unsafe_weak_copy(lhs_indices), rhs_indices = array::unsafe_weak_copy(rhs_indices), group_size_ = group_size_, bits_ = bits_, transpose_ = transpose_]() mutable { _bs_qmm_dispatch( out, x, w, scales, biases, lhs_indices, rhs_indices, group_size_, bits_, transpose_); }); } else { encoder.dispatch([out = array::unsafe_weak_copy(out), x = array::unsafe_weak_copy(x), w = array::unsafe_weak_copy(w), scales = array::unsafe_weak_copy(scales), lhs_indices = array::unsafe_weak_copy(lhs_indices), rhs_indices = array::unsafe_weak_copy(rhs_indices), group_size_ = group_size_, bits_ = bits_, transpose_ = transpose_]() mutable { fp_bs_qmm_dispatch( out, x, w, scales, lhs_indices, rhs_indices, group_size_, bits_, transpose_); }); } } uint8_t to_fp8_e8m0(float x) { if (!std::isfinite(x)) { return 0xFF; } if (x < 0.0f) { return 0x00; } float le = std::log2(x); int n = int(std::round(le)); n = n < -127 ? -127 : n; n = n > 127 ? 127 : n; return static_cast(n + 127); } uint8_t to_fp4_e2m1(float x) { if (std::isnan(x)) { return 0x7; } const uint8_t sign_bit = (std::signbit(x)) ? 
0x8 : 0x0; x = std::abs(x); uint8_t bits; if (x > 5.0f) { bits = 0x7; } else if (x >= 3.5f) { bits = 0x6; } else if (x > 2.5f) { bits = 0x5; } else if (x >= 1.75f) { bits = 0x4; } else if (x > 1.25f) { bits = 0x3; } else if (x >= 0.75f) { bits = 0x2; } else if (x > 0.25f) { bits = 0x1; } else { bits = 0x0; } return bits | sign_bit; } template void fp_quantize_dequantize( const array& w_arr, array& out_arr, int bits, int group_size, size_t w_size) { auto w = w_arr.data(); auto out = out_arr.data(); size_t n_groups = w_size / group_size; for (size_t i = 0; i < n_groups; ++i) { size_t idx = i * group_size; float scale = -std::numeric_limits::infinity(); for (int j = 0; j < group_size; ++j) { scale = std::max(scale, std::abs(w[idx + j])); } scale /= bits == 4 ? 6.0f : 448.0f; if (group_size == 16) { scale = dequantize_scale(detail::ToFP8()(scale)); } else { scale = dequantize_scale(to_fp8_e8m0(scale)); } for (int j = 0; j < group_size; ++j) { float w_el = scale == 0 ? 0.0f : w[idx + j] / scale; float output; if (bits == 8) { output = detail::FromFP8()(detail::ToFP8()(w_el)); } else { output = FP4_LUT[to_fp4_e2m1(w_el)]; } out[idx + j] = static_cast(scale * output); } } } void dispatch_quantize_dequantize( const array& w, array& out, int bits, int group_size) { if (w.dtype() == float16) { fp_quantize_dequantize(w, out, bits, group_size, w.size()); } else if (w.dtype() == bfloat16) { fp_quantize_dequantize(w, out, bits, group_size, w.size()); } else if (w.dtype() == float32) { fp_quantize_dequantize(w, out, bits, group_size, w.size()); } else { throw std::runtime_error( "[quantize_dequantize] Only supports floating point inputs"); } } template void quantize( const T* w, U* out, T* scales, T* biases, int bits, int group_size, size_t w_size) { float n_bins = (1 << bits) - 1; float eps = 1e-7; bool power_of_2_bits = is_power_of_2(bits); int el_per_int = get_pack_factor(bits, 32); int bytes_per_pack = get_bytes_per_pack(bits); int int_per_group = group_size * bytes_per_pack 
/ el_per_int; size_t n_groups = w_size / group_size; for (size_t i = 0; i < n_groups; ++i) { size_t w_idx = i * group_size; float w_min = std::numeric_limits::infinity(); float w_max = -w_min; for (int j = 0; j < group_size; ++j) { w_max = std::max(w_max, (float)w[w_idx + j]); w_min = std::min(w_min, (float)w[w_idx + j]); } bool mask = std::abs(w_min) > std::abs(w_max); float scale = std::max((w_max - w_min) / n_bins, eps); scale = mask ? scale : -scale; float edge = mask ? w_min : w_max; float q0 = std::rint(edge / scale); float bias = 0; if (q0 != 0) { scale = edge / q0; bias = edge; } size_t out_idx = i * int_per_group; for (int j = 0; j < int_per_group / bytes_per_pack; ++j) { uint64_t out_el = 0; for (int k = 0; k < el_per_int; ++k) { float w_el = w[w_idx + j * el_per_int + k]; w_el = std::rint((w_el - bias) / scale); w_el = std::min(std::max(w_el, 0.0f), n_bins); out_el |= static_cast(w_el) << (k * bits); } if (power_of_2_bits) { out[out_idx + j] = out_el; } else if (bits == 5) { out[out_idx + bytes_per_pack * j] = out_el & 0xff; out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8; out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16; out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24; out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32; } else { out[out_idx + bytes_per_pack * j] = out_el & 0xff; out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8; out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16; } } scales[i] = static_cast(scale); biases[i] = static_cast(bias); } } template void dispatch_quantize( const array& w, array& out, array& scales, array& biases, int bits, int group_size) { auto w_ptr = w.data(); auto out_ptr = out.data(); auto scales_ptr = scales.data(); auto biases_ptr = biases.data(); quantize( w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size()); } void fast::Quantize::eval_cpu( const std::vector& inputs, std::vector& outputs) { auto& encoder = 
cpu::get_command_encoder(stream()); auto w = ensure_row_contiguous(inputs[0], encoder, stream()); auto& out = outputs[0]; out.set_data(allocator::malloc(out.nbytes())); auto& scales = outputs[1]; auto& biases = outputs[2]; scales.set_data(allocator::malloc(scales.nbytes())); biases.set_data(allocator::malloc(biases.nbytes())); encoder.set_input_array(w); encoder.set_input_array(scales); encoder.set_input_array(biases); encoder.set_output_array(out); encoder.dispatch([w = array::unsafe_weak_copy(w), out = array::unsafe_weak_copy(out), scales = array::unsafe_weak_copy(scales), biases = array::unsafe_weak_copy(biases), group_size_ = group_size_, bits_ = bits_]() mutable { if (w.dtype() == float16) { if (is_power_of_2(bits_)) { dispatch_quantize( w, out, scales, biases, bits_, group_size_); } else { dispatch_quantize( w, out, scales, biases, bits_, group_size_); } } else if (w.dtype() == bfloat16) { if (is_power_of_2(bits_)) { dispatch_quantize( w, out, scales, biases, bits_, group_size_); } else { dispatch_quantize( w, out, scales, biases, bits_, group_size_); } } else if (w.dtype() == float32) { if (is_power_of_2(bits_)) { dispatch_quantize( w, out, scales, biases, bits_, group_size_); } else { dispatch_quantize( w, out, scales, biases, bits_, group_size_); } } else { throw std::runtime_error( "[fast::Quantize::eval_cpu] Only supports floating point inputs"); } }); } void fast::ConvertFP8::eval_cpu( const std::vector& inputs, std::vector& outputs) { auto& in = inputs[0]; auto& out = outputs[0]; set_unary_output_data(in, out); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); encoder.set_output_array(out); encoder.dispatch([in = array::unsafe_weak_copy(in), out = array::unsafe_weak_copy(out), to_fp8 = to_fp8_]() mutable { if (to_fp8) { switch (in.dtype()) { case float16: unary_op(in, out, detail::ToFP8()); break; case bfloat16: unary_op(in, out, detail::ToFP8()); break; default: unary_op(in, out, detail::ToFP8()); break; } } else { 
switch (out.dtype()) { case float16: unary_op(in, out, detail::FromFP8()); break; case bfloat16: unary_op(in, out, detail::FromFP8()); break; default: unary_op(in, out, detail::FromFP8()); break; } } }); } void QQMatmul::eval_cpu(const std::vector& inputs, array& out) { auto& encoder = cpu::get_command_encoder(stream()); bool w_quantized = (inputs[1].dtype() == uint32); if (w_quantized && inputs[0].shape(-2) == 1) { bool donate_x = inputs[0].is_donatable(); auto x = ensure_row_contiguous(inputs[0], encoder, stream()); auto w = ensure_row_contiguous(inputs[1], encoder, stream()); auto scales = ensure_row_contiguous(inputs[2], encoder, stream()); out.set_data(allocator::malloc(out.nbytes())); // If x is a copy it should be donatable donate_x |= x.is_donatable(); auto xhat = donate_x ? x : array(allocator::malloc(x.nbytes()), x.shape(), x.dtype()); if (!donate_x) { encoder.add_temporary(xhat); } encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(scales); encoder.set_output_array(out); encoder.dispatch([out = array::unsafe_weak_copy(out), x = array::unsafe_weak_copy(x), xhat = array::unsafe_weak_copy(xhat), w = array::unsafe_weak_copy(w), scales = array::unsafe_weak_copy(scales), group_size_ = group_size_, bits_ = bits_]() mutable { dispatch_quantize_dequantize(x, xhat, bits_, group_size_); fp_qmm_dispatch(out, xhat, w, scales, group_size_, bits_, true); }); return; } else { throw std::runtime_error("[QQMatmul] NYI for the general case"); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/reduce.cpp ================================================ // Copyright © 2023 Apple Inc. 
#include #include #include #include "mlx/backend/common/reduce.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" #include "mlx/primitives.h" namespace mlx::core { template struct Limits { static const U max; static const U min; }; #define instantiate_default_limit(type) \ template <> \ struct Limits { \ static constexpr type max = std::numeric_limits::max(); \ static constexpr type min = std::numeric_limits::min(); \ }; instantiate_default_limit(uint8_t); instantiate_default_limit(uint16_t); instantiate_default_limit(uint32_t); instantiate_default_limit(uint64_t); instantiate_default_limit(int8_t); instantiate_default_limit(int16_t); instantiate_default_limit(int32_t); instantiate_default_limit(int64_t); #define instantiate_float_limit(type) \ template <> \ struct Limits { \ static const type max; \ static const type min; \ }; instantiate_float_limit(float16_t); instantiate_float_limit(bfloat16_t); instantiate_float_limit(float); instantiate_float_limit(double); instantiate_float_limit(complex64_t); template <> struct Limits { static constexpr bool max = true; static constexpr bool min = false; }; const float Limits::max = std::numeric_limits::infinity(); const float Limits::min = -std::numeric_limits::infinity(); const bfloat16_t Limits::max = std::numeric_limits::infinity(); const bfloat16_t Limits::min = -std::numeric_limits::infinity(); const float16_t Limits::max = std::numeric_limits::infinity(); const float16_t Limits::min = -std::numeric_limits::infinity(); const double Limits::max = std::numeric_limits::infinity(); const double Limits::min = -std::numeric_limits::infinity(); const complex64_t Limits::max = std::numeric_limits::infinity(); const complex64_t Limits::min = -std::numeric_limits::infinity(); template void strided_reduce( const T* x, U* accumulator, int size, size_t stride, Op op) { constexpr int N = std::min(simd::max_size, simd::max_size); for (int i = 0; i < size; i++) { U* moving_accumulator = accumulator; auto s 
= stride; while (s >= N) { auto acc = simd::load(moving_accumulator); auto v = simd::Simd(simd::load(x)); simd::store(moving_accumulator, op(acc, v)); moving_accumulator += N; x += N; s -= N; } while (s-- > 0) { *moving_accumulator = op(*moving_accumulator, *x); moving_accumulator++; x++; } } }; template void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) { constexpr int N = std::min(simd::max_size, simd::max_size); simd::Simd accumulator_v(init); while (size >= N) { accumulator_v = op(accumulator_v, simd::Simd(simd::load(x))); x += N; size -= N; } *accumulator = op(*accumulator, op(accumulator_v)); while (size-- > 0) { *accumulator = op(*accumulator, *x); x++; } } // Helper for the ndimensional strided loop void nd_loop( std::function callback, const Shape& shape, const Strides& strides) { std::function loop_inner; loop_inner = [&](int dim, int offset) { if (dim < shape.size() - 1) { auto size = shape[dim]; auto stride = strides[dim]; for (int i = 0; i < size; i++) { loop_inner(dim + 1, offset + i * stride); } } else { auto size = shape[dim]; auto stride = strides[dim]; for (int i = 0; i < size; i++) { callback(offset + i * stride); } } }; loop_inner(0, 0); } template void reduction_op( const array& x, array& out, const std::vector& axes, U init) { ReductionPlan plan = get_reduction_plan(x, axes); auto in_ptr = x.data(); auto out_ptr = out.data(); if (plan.type == ContiguousAllReduce) { *out_ptr = init; contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init); return; } if (plan.type == ContiguousReduce && plan.shape.size() == 1) { int reduction_size = plan.shape[0]; for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) { *out_ptr = init; contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init); } return; } if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) { int reduction_size = plan.shape.back(); plan.shape.pop_back(); plan.strides.pop_back(); // Unrolling the following loop (and 
implementing it in order for // ContiguousReduce) should hold extra performance boost. auto [shape, strides] = shapes_without_reduction_axes(x, axes); if (plan.shape.size() == 0) { for (int i = 0; i < out.size(); i++, out_ptr++) { int offset = elem_to_loc(i, shape, strides); *out_ptr = init; contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init); } } else { for (int i = 0; i < out.size(); i++, out_ptr++) { int offset = elem_to_loc(i, shape, strides); *out_ptr = init; nd_loop( [&](int extra_offset) { contiguous_reduce( in_ptr + offset + extra_offset, out_ptr, reduction_size, Op{}, init); }, plan.shape, plan.strides); } } return; } if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) { int reduction_size = plan.shape.back(); size_t reduction_stride = plan.strides.back(); plan.shape.pop_back(); plan.strides.pop_back(); for (int i = 0; i < out.size(); i += reduction_stride) { std::fill_n(out_ptr, reduction_stride, init); strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{}); in_ptr += reduction_stride * reduction_size; out_ptr += reduction_stride; } return; } if (plan.type == GeneralStridedReduce || plan.type == ContiguousStridedReduce) { int reduction_size = plan.shape.back(); size_t reduction_stride = plan.strides.back(); plan.shape.pop_back(); plan.strides.pop_back(); auto [shape, strides] = shapes_without_reduction_axes(x, axes); if (plan.shape.size() == 0) { for (int i = 0; i < out.size(); i += reduction_stride) { int offset = elem_to_loc(i, shape, strides); std::fill_n(out_ptr, reduction_stride, init); strided_reduce( in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{}); out_ptr += reduction_stride; } } else { for (int i = 0; i < out.size(); i += reduction_stride) { int offset = elem_to_loc(i, shape, strides); std::fill_n(out_ptr, reduction_stride, init); nd_loop( [&](int extra_offset) { strided_reduce( in_ptr + offset + extra_offset, out_ptr, reduction_size, reduction_stride, Op{}); }, plan.shape, 
plan.strides); out_ptr += reduction_stride; } } return; } if (plan.type == GeneralReduce) { auto [shape, strides] = shapes_without_reduction_axes(x, axes); for (int i = 0; i < out.size(); i++, out_ptr++) { int offset = elem_to_loc(i, shape, strides); U val = init; nd_loop( [&](int extra_offset) { val = Op{}(val, *(in_ptr + offset + extra_offset)); }, plan.shape, plan.strides); *out_ptr = val; } } } /* AndReduce: logical-AND reduction functor. The scalar overloads treat any non-zero input as true; the single-argument overload collapses a vector into one bool with simd::all. */ struct AndReduce { template bool operator()(bool x, T y) { return x & (y != 0); } bool operator()(bool x, bool y) { return x & y; } template simd::Simd operator()(simd::Simd y, simd::Simd x) { return x & (y != 0); }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return x & y; }; template bool operator()(simd::Simd x) { return simd::all(x); }; }; /* OrReduce: logical-OR counterpart of AndReduce; a vector is collapsed with simd::any. */ struct OrReduce { template bool operator()(bool x, T y) { return x | (y != 0); } bool operator()(bool x, bool y) { return x | y; } template simd::Simd operator()(simd::Simd y, simd::Simd x) { return x | (y != 0); }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return x | y; }; template bool operator()(simd::Simd x) { return simd::any(x); }; }; /* MaxReduce: the two-scalar overload wraps both operands in Simd values and forwards to the vector overload (simd::maximum). There are two collapse overloads: one returns simd::max directly, the other first returns NaN if any lane compares unequal to itself. The enable_if conditions choosing between them are not visible in this copy; presumably integral vs. floating T (TODO confirm). */ struct MaxReduce { template T operator()(T y, T x) { return (*this)(simd::Simd(x), simd::Simd(y)).value; }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return simd::maximum(x, y); }; template std::enable_if_t, T> operator()(simd::Simd x) { return simd::max(x); }; template std::enable_if_t, T> operator()(simd::Simd x) { if (simd::any(x != x)) { return static_cast(NAN); } return simd::max(x); }; }; /* MinReduce: mirror of MaxReduce using simd::minimum / simd::min, with the same NaN check in one of the collapse overloads. */ struct MinReduce { template T operator()(T y, T x) { return (*this)(simd::Simd(x), simd::Simd(y)).value; }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return simd::minimum(x, y); }; template std::enable_if_t, T> operator()(simd::Simd x) { return simd::min(x); }; template std::enable_if_t, T> operator()(simd::Simd x) { if (simd::any(x != x)) { return static_cast(NAN); } return simd::min(x); }; }; /* SumReduce: accumulates with + and collapses a vector with simd::sum. */ struct SumReduce { template U 
operator()(U y, T x) { return x + y; }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return y + x; }; template T operator()(simd::Simd x) { return simd::sum(x); }; }; /* ProdReduce: accumulates with * and collapses a vector with simd::prod. */ struct ProdReduce { template U operator()(U y, T x) { return x * y; }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return x * y; }; template T operator()(simd::Simd x) { return simd::prod(x); }; }; /* Runs reduction_op with initial value true for Reduce::And and false for any other type (i.e. Reduce::Or). */ template void reduce_dispatch_and_or( const array& in, array& out, Reduce::ReduceType rtype, const std::vector& axes) { if (rtype == Reduce::And) { reduction_op(in, out, axes, true); } else { reduction_op(in, out, axes, false); } } /* Runs reduction_op with initial value 0 for Reduce::Sum and 1 otherwise (Prod). Integral inputs of at most 4 bytes take a separate reduction_op instantiation whose template arguments are not visible in this copy. NOTE(review): presumably a wider accumulator type; confirm. */ template void reduce_dispatch_sum_prod( const array& in, array& out, Reduce::ReduceType rtype, const std::vector& axes) { if (rtype == Reduce::Sum) { if constexpr (std::is_integral_v && sizeof(InT) <= 4) { reduction_op(in, out, axes, 0); } else { reduction_op(in, out, axes, 0); } } else { if constexpr (std::is_integral_v && sizeof(InT) <= 4) { reduction_op(in, out, axes, 1); } else { reduction_op(in, out, axes, 1); } } } /* Max reductions start from Limits::min and Min reductions from Limits::max; Limits' template argument is not visible in this copy. */ template void reduce_dispatch_min_max( const array& in, array& out, Reduce::ReduceType rtype, const std::vector& axes) { if (rtype == Reduce::Max) { auto init = Limits::min; reduction_op(in, out, axes, init); } else { auto init = Limits::max; reduction_op(in, out, axes, init); } } /* CPU entry point for Reduce. Steps: (1) allocate the output buffer; (2) register input and output with the stream's CPU command encoder; (3) dispatch a task that switches on the reduce type and then on the input dtype. And/Or group dtypes by element width (1, 2, 4 and 8 bytes, with complex64 in the 8-byte group). Sum/Prod and Max/Min dispatch on each dtype individually. */ void Reduce::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); encoder.set_output_array(out); encoder.dispatch([in = array::unsafe_weak_copy(in), out = array::unsafe_weak_copy(out), reduce_type_ = reduce_type_, axes_ = axes_]() mutable { switch (reduce_type_) { case Reduce::And: case Reduce::Or: { switch (in.dtype()) { case bool_: case uint8: case int8: reduce_dispatch_and_or(in, out, reduce_type_, axes_); break; case int16: case uint16: case float16: case bfloat16: 
reduce_dispatch_and_or(in, out, reduce_type_, axes_); break; case uint32: case int32: case float32: reduce_dispatch_and_or(in, out, reduce_type_, axes_); break; case uint64: case int64: case float64: case complex64: reduce_dispatch_and_or(in, out, reduce_type_, axes_); break; } break; } case Reduce::Sum: case Reduce::Prod: { switch (in.dtype()) { case bool_: case uint8: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case uint16: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case uint32: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case uint64: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int8: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int16: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int32: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int64: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case float16: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case bfloat16: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case float32: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case float64: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case complex64: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; } break; } case Reduce::Max: case Reduce::Min: { switch (in.dtype()) { case bool_: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case uint8: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case uint16: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case uint32: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case uint64: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int8: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int16: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int32: reduce_dispatch_min_max(in, out, 
reduce_type_, axes_); break; case int64: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case float16: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case float32: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case float64: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case bfloat16: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case complex64: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; } break; } } }); } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/scan.cpp ================================================ // Copyright © 2023 Apple Inc. #include #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/binary_ops.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" #include "mlx/primitives.h" namespace mlx::core { namespace { /* contiguous_scan: scans `count` consecutive runs of `stride` elements each (the scan axis is the innermost one). An inclusive scan seeds each run with its first input. An exclusive scan writes `init` first, so output j combines inputs before j. With `reverse`, each run is processed from its last element backwards. */ template void contiguous_scan( const T* input, U* output, int count, int stride, bool reverse, bool inclusive, const Op& op, U init) { if (!reverse) { if (inclusive) { for (int i = 0; i < count; i++) { *output = *input; for (int j = 1; j < stride; j++) { input++; output++; *output = op(*(output - 1), *input); } output++; input++; } } else { for (int i = 0; i < count; i++) { *output = init; for (int j = 1; j < stride; j++) { *(output + 1) = op(*output, *input); input++; output++; } output++; input++; } } } else { if (inclusive) { for (int i = 0; i < count; i++) { output += stride - 1; input += stride - 1; *output = *input; for (int j = 1; j < stride; j++) { input--; output--; *output = op(*(output + 1), *input); } output += stride; input += stride; } } else { for (int i = 0; i < count; i++) { output += stride - 1; input += stride - 1; *output = init; for (int j = 1; j < stride; j++) { *(output - 1) = op(*output, *input); input--; output--; } output += stride; input += stride; } } } }; /* strided_scan: for each of `count` outer blocks, scans `size` rows of `stride` elements. Each row is combined element-wise with the previous row (or the next row when `reverse`). Exclusive scans fill the first row with `init`. */ template void strided_scan( const T* input, U* 
output, int count, int size, int stride, bool reverse, bool inclusive, const Op& op, U init) { // TODO: Vectorize the following naive implementation if (!reverse) { if (inclusive) { for (int i = 0; i < count; i++) { std::copy(input, input + stride, output); output += stride; input += stride; for (int j = 1; j < size; j++) { for (int k = 0; k < stride; k++) { *output = op(*(output - stride), *input); output++; input++; } } } } else { for (int i = 0; i < count; i++) { std::fill(output, output + stride, init); output += stride; input += stride; for (int j = 1; j < size; j++) { for (int k = 0; k < stride; k++) { *output = op(*(output - stride), *(input - stride)); output++; input++; } } } } } else { if (inclusive) { for (int i = 0; i < count; i++) { output += (size - 1) * stride; input += (size - 1) * stride; std::copy(input, input + stride, output); for (int j = 1; j < size; j++) { for (int k = 0; k < stride; k++) { output--; input--; *output = op(*(output + stride), *input); } } output += size * stride; input += size * stride; } } else { for (int i = 0; i < count; i++) { output += (size - 1) * stride; input += (size - 1) * stride; std::fill(output, output + stride, init); for (int j = 1; j < size; j++) { for (int k = 0; k < stride; k++) { output--; input--; *output = op(*(output + stride), *(input + stride)); } } output += size * stride; input += size * stride; } } } }; /* scan_op: uses contiguous_scan when the scan axis has unit stride and strided_scan otherwise. Throws std::runtime_error for inputs that are not row-contiguous. NOTE(review): the element counts divide by in.shape(axis); confirm callers never reach this with an empty scan axis. */ template void scan_op( const array& in, array& out, int axis, bool reverse, bool inclusive, const Op& op, U init) { if (in.flags().row_contiguous) { if (in.strides()[axis] == 1) { contiguous_scan( in.data(), out.data(), in.size() / in.shape(axis), in.shape(axis), reverse, inclusive, op, init); } else { strided_scan( in.data(), out.data(), in.size() / in.shape(axis) / in.strides()[axis], in.shape(axis), in.strides()[axis], reverse, inclusive, op, init); } } else { throw std::runtime_error("Scan op supports only contiguous inputs"); } } /* scan_dispatch: builds the combine lambda and the initial value for each Scan::ReduceType. Min, Max and LogAddExp use +/- infinity as the initial value for floating dtypes, and numeric_limits max/min otherwise. */ template void scan_dispatch( Scan::ReduceType rtype, const array& 
in, array& out, int axis, bool reverse, bool inclusive) { switch (rtype) { case Scan::Sum: { auto op = [](U y, T x) { return y + x; }; auto init = static_cast(0); scan_op(in, out, axis, reverse, inclusive, op, init); break; } case Scan::Prod: { auto op = [](U y, T x) { return y * x; }; auto init = static_cast(1); scan_op(in, out, axis, reverse, inclusive, op, init); break; } case Scan::Min: { auto op = [](U y, T x) { return x < y ? x : y; }; auto init = (issubdtype(in.dtype(), floating)) ? static_cast(std::numeric_limits::infinity()) : std::numeric_limits::max(); scan_op(in, out, axis, reverse, inclusive, op, init); break; } case Scan::Max: { auto op = [](U y, T x) { return x < y ? y : x; }; auto init = (issubdtype(in.dtype(), floating)) ? static_cast(-std::numeric_limits::infinity()) : std::numeric_limits::min(); scan_op(in, out, axis, reverse, inclusive, op, init); break; } case Scan::LogAddExp: { auto op = [](U a, T b) { return detail::LogAddExp{}(a, static_cast(b)); }; auto init = (issubdtype(in.dtype(), floating)) ? static_cast(-std::numeric_limits::infinity()) : std::numeric_limits::min(); scan_op(in, out, axis, reverse, inclusive, op, init); break; } } } } // namespace /* CPU entry point for Scan. Makes a row-contiguous copy of the input when needed and keeps it alive as an encoder temporary. Then allocates the output and dispatches scan_dispatch on the input dtype. Bool inputs summed into an int32 output take a separate instantiation. */ void Scan::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& encoder = cpu::get_command_encoder(stream()); // Ensure contiguity auto in = inputs[0]; if (!in.flags().row_contiguous) { in = contiguous_copy_cpu(in, stream()); encoder.add_temporary(in); } out.set_data(allocator::malloc(out.nbytes())); encoder.set_input_array(in); encoder.set_output_array(out); encoder.dispatch([in = array::unsafe_weak_copy(in), out = array::unsafe_weak_copy(out), axis_ = axis_, reduce_type_ = reduce_type_, reverse_ = reverse_, inclusive_ = inclusive_]() mutable { switch (in.dtype()) { case bool_: { // We could do a full dtype x dtype switch but this is the only case // where we accumulate in a different type, for now. 
// // TODO: If we add the option to accumulate floats in higher precision // floats perhaps we should add the full all-to-all dispatch. if (reduce_type_ == Scan::Sum && out.dtype() == int32) { scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); } else { scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); } break; } case uint8: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case uint16: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case uint32: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case uint64: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case int8: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case int16: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case int32: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case int64: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case float16: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case float32: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case float64: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case bfloat16: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; case complex64: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; } }); } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/select.cpp ================================================ // Copyright © 2023 Apple Inc.
#include #include "mlx/backend/cpu/binary_ops.h" #include "mlx/backend/cpu/ternary.h" #include "mlx/primitives.h" namespace mlx::core { namespace { /* select_op: classifies the three inputs with get_ternary_op_type and sets up the output storage for that classification. It then registers all four arrays with the stream's CPU command encoder and dispatches ternary_op on the output dtype. The per-dtype template arguments of ternary_op are not visible in this copy. */ template void select_op( const array& a, const array& b, const array& c, array& out, Op op, Stream stream) { TernaryOpType topt = get_ternary_op_type(a, b, c); set_ternary_op_output_data(a, b, c, out, topt); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(c); encoder.set_output_array(out); encoder.dispatch([a = array::unsafe_weak_copy(a), b = array::unsafe_weak_copy(b), c = array::unsafe_weak_copy(c), out = array::unsafe_weak_copy(out), op, topt]() mutable { switch (out.dtype()) { case bool_: ternary_op(a, b, c, out, op, topt); break; case uint8: ternary_op(a, b, c, out, op, topt); break; case uint16: ternary_op(a, b, c, out, op, topt); break; case uint32: ternary_op(a, b, c, out, op, topt); break; case uint64: ternary_op(a, b, c, out, op, topt); break; case int8: ternary_op(a, b, c, out, op, topt); break; case int16: ternary_op(a, b, c, out, op, topt); break; case int32: ternary_op(a, b, c, out, op, topt); break; case int64: ternary_op(a, b, c, out, op, topt); break; case float16: ternary_op( a, b, c, out, op, topt); break; case float32: ternary_op(a, b, c, out, op, topt); break; case float64: ternary_op(a, b, c, out, op, topt); break; case bfloat16: ternary_op( a, b, c, out, op, topt); break; case complex64: ternary_op( a, b, c, out, op, topt); break; } }); } } // namespace /* CPU entry point for Select: inputs[0] is the condition, and inputs[1] and inputs[2] are the two value arrays; all three are forwarded to select_op with detail::Select. */ void Select::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); const auto& condition = inputs[0]; const auto& a = inputs[1]; const auto& b = inputs[2]; select_op(condition, a, b, out, detail::Select(), stream()); } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/simd/accelerate_fp16_simd.h ================================================ #pragma once #include "mlx/backend/cpu/simd/base_simd.h"
#if MLX_SIMD_LIBRARY_VERSION < 6 #include "mlx/backend/cpu/simd/neon_fp16_simd.h" #endif /* Below SIMD library version 6 the NEON fp16 header supplies N and the float16 Simd type. Otherwise N is defined here and a ScalarT specialization maps the scalar type to _Float16; its template arguments are not visible in this copy. */ namespace mlx::core::simd { #if MLX_SIMD_LIBRARY_VERSION >= 6 constexpr int N = 8; template struct ScalarT { using v = _Float16; }; #endif template <> inline constexpr int max_size = N; /* The two macros below specialize fp16 math by assigning the inputs to another Simd type, calling op on that, and converting back through the return type. NOTE(review): the intermediate type is not visible in this copy; presumably float. */ #define SIMD_FP16_DEFAULT_UNARY(op) \ template <> \ inline Simd op(Simd v) { \ Simd in = v; \ return op(in); \ } SIMD_FP16_DEFAULT_UNARY(acos) SIMD_FP16_DEFAULT_UNARY(acosh) SIMD_FP16_DEFAULT_UNARY(asin) SIMD_FP16_DEFAULT_UNARY(asinh) SIMD_FP16_DEFAULT_UNARY(atan) SIMD_FP16_DEFAULT_UNARY(atanh) SIMD_FP16_DEFAULT_UNARY(cosh) SIMD_FP16_DEFAULT_UNARY(expm1) SIMD_FP16_DEFAULT_UNARY(log) SIMD_FP16_DEFAULT_UNARY(log2) SIMD_FP16_DEFAULT_UNARY(log10) SIMD_FP16_DEFAULT_UNARY(log1p) SIMD_FP16_DEFAULT_UNARY(sinh) SIMD_FP16_DEFAULT_UNARY(tan) SIMD_FP16_DEFAULT_UNARY(tanh) #define SIMD_FP16_DEFAULT_BINARY(op) \ template <> \ inline Simd op(Simd x, Simd y) { \ Simd a = x; \ Simd b = y; \ return op(a, b); \ } SIMD_FP16_DEFAULT_BINARY(atan2) SIMD_FP16_DEFAULT_BINARY(remainder) SIMD_FP16_DEFAULT_BINARY(pow) } // namespace mlx::core::simd ================================================ FILE: mlx/backend/cpu/simd/accelerate_simd.h ================================================ #pragma once #include #include #include #include #include #include #include "mlx/backend/cpu/simd/base_simd.h" // There seems to be a bug in simd/base_simd.h // __XROS_2_0 is not defined, the expression evaluates // to true instead of false setting the SIMD library // higher than it should be even on macOS < 15 /* One minimum-version test per Apple platform: macOS 15, iOS 18, watchOS 11 and tvOS 18. visionOS is deliberately left out for the reason given in the comment above. */ #if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \ __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \ __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \ __TV_OS_VERSION_MIN_REQUIRED >= 180000 #define MLX_SIMD_LIBRARY_VERSION 6 #else #define MLX_SIMD_LIBRARY_VERSION 5 #endif namespace mlx::core::simd { // Apple simd namespace namespace asd = ::simd; // This indirection is needed to remap
certain types to ones that accelerate // SIMD can handle template struct ScalarT { using v = T; }; template struct ScalarT { using v = char; }; template struct ScalarT { using v = char; }; template struct ScalarT { using v = unsigned long; }; template struct ScalarT { using v = long; }; /* Simd wraps an Accelerate packed vector of the remapped scalar type; operator[] reinterprets that storage as an array of T. */ template struct Simd { static constexpr int size = N; using scalar_t = typename ScalarT::v; Simd() {} template Simd(Simd other) : value(asd::convert(other.value)) {} template Simd(U v) : value(v){}; Simd(Simd x, Simd y) { value = asd::make::packed_t>( x.value, y.value); }; T operator[](int idx) const { return reinterpret_cast(&value)[idx]; } T& operator[](int idx) { return reinterpret_cast(&value)[idx]; } typename asd::Vector::packed_t value; }; // Values chosen based on benchmarks on M3 Max // TODO: consider choosing these more optimally template <> inline constexpr int max_size = 16; template <> inline constexpr int max_size = 16; template <> inline constexpr int max_size = 8; template <> inline constexpr int max_size = 4; template <> inline constexpr int max_size = 16; template <> inline constexpr int max_size = 16; template <> inline constexpr int max_size = 8; template <> inline constexpr int max_size = 4; template <> inline constexpr int max_size = 8; template <> inline constexpr int max_size = 4; /* Element-wise unary ops forwarded to the matching asd:: (Apple simd) function. */ #define SIMD_DEFAULT_UNARY(name, op) \ template \ Simd name(Simd v) { \ return op(v.value); \ } SIMD_DEFAULT_UNARY(abs, asd::abs) SIMD_DEFAULT_UNARY(floor, asd::floor) SIMD_DEFAULT_UNARY(acos, asd::acos) SIMD_DEFAULT_UNARY(acosh, asd::acosh) SIMD_DEFAULT_UNARY(asin, asd::asin) SIMD_DEFAULT_UNARY(asinh, asd::asinh) SIMD_DEFAULT_UNARY(atan, asd::atan) SIMD_DEFAULT_UNARY(atanh, asd::atanh) SIMD_DEFAULT_UNARY(ceil, asd::ceil) SIMD_DEFAULT_UNARY(cosh, asd::cosh) SIMD_DEFAULT_UNARY(expm1, asd::expm1) SIMD_DEFAULT_UNARY(log, asd::log) SIMD_DEFAULT_UNARY(log2, asd::log2) SIMD_DEFAULT_UNARY(log10, asd::log10) SIMD_DEFAULT_UNARY(log1p, asd::log1p) SIMD_DEFAULT_UNARY(rint, asd::rint)
SIMD_DEFAULT_UNARY(sinh, asd::sinh) SIMD_DEFAULT_UNARY(sqrt, asd::sqrt) SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt) SIMD_DEFAULT_UNARY(recip, asd::recip) SIMD_DEFAULT_UNARY(tan, asd::tan) SIMD_DEFAULT_UNARY(tanh, asd::tanh) template Simd operator-(Simd v) { return -v.value; } template Simd operator~(Simd v) { return ~v.value; } /* isnan relies on NaN being the only value that compares unequal to itself. */ template Simd isnan(Simd v) { return asd::convert(v.value != v.value); } // No simd_boolN in accelerate, use int8_t instead template Simd operator!(Simd v) { return asd::convert(!v.value); } /* Binary arithmetic, shift, bitwise and logical operators for vector-scalar, scalar-vector and vector-vector operands. Each result is passed back through asd::convert. */ #define SIMD_DEFAULT_BINARY(OP) \ template \ Simd operator OP(Simd x, U y) { \ return asd::convert::scalar_t>(x.value OP y); \ } \ template \ Simd operator OP(T1 x, Simd y) { \ return asd::convert::scalar_t>(x OP y.value); \ } \ template \ Simd operator OP(Simd x, Simd y) { \ return asd::convert::scalar_t>(x.value OP y.value); \ } SIMD_DEFAULT_BINARY(+) SIMD_DEFAULT_BINARY(-) SIMD_DEFAULT_BINARY(/) SIMD_DEFAULT_BINARY(*) SIMD_DEFAULT_BINARY(<<) SIMD_DEFAULT_BINARY(>>) SIMD_DEFAULT_BINARY(|) SIMD_DEFAULT_BINARY(^) SIMD_DEFAULT_BINARY(&) SIMD_DEFAULT_BINARY(&&) SIMD_DEFAULT_BINARY(||) /* Comparisons produce mask vectors via asd::convert; the mask element type is not visible in this copy. */ #define SIMD_DEFAULT_COMPARISONS(OP) \ template \ Simd operator OP(Simd a, U b) { \ return asd::convert(a.value OP b); \ } \ template \ Simd operator OP(T a, Simd b) { \ return asd::convert(a OP b.value); \ } \ template \ Simd operator OP(Simd a, Simd b) { \ return asd::convert(a.value OP b.value); \ } SIMD_DEFAULT_COMPARISONS(>) SIMD_DEFAULT_COMPARISONS(<) SIMD_DEFAULT_COMPARISONS(>=) SIMD_DEFAULT_COMPARISONS(<=) SIMD_DEFAULT_COMPARISONS(==) SIMD_DEFAULT_COMPARISONS(!=) /* clz reinterprets x as two uint32x4_t halves, counts leading zeros with NEON vclzq_u32, and rejoins the halves with asd::make_uint8. NOTE(review): this assumes an 8-lane, 32-bit layout; confirm for every instantiation. */ template Simd clz(Simd x) { auto a = *(uint32x4_t*)(&x); auto b = *((uint32x4_t*)(&x) + 1); a = vclzq_u32(a); b = vclzq_u32(b); return asd::make_uint8(a, b); } template Simd atan2(Simd a, Simd b) { return asd::atan2(a.value, b.value); } /* maximum/minimum: for non-integral types a NaN in either operand is propagated through select; when both are NaN, b's lane is returned. */ template Simd maximum(Simd a, Simd b) { auto out = Simd(asd::max(a.value, b.value)); if constexpr (!std::is_integral_v) { out = select(isnan(b), b, select(isnan(a), a, out)); } 
return out; } template Simd minimum(Simd a, Simd b) { auto out = Simd(asd::min(a.value, b.value)); if constexpr (!std::is_integral_v) { out = select(isnan(b), b, select(isnan(a), a, out)); } return out; } /* remainder: for signed types, a non-zero result whose sign differs from b's is shifted by b so that it takes the divisor's sign. */ template Simd remainder(Simd a, Simd b) { Simd r; if constexpr (!std::is_integral_v) { r = asd::remainder(a.value, b.value); } else { r = a - b * (a / b); } if constexpr (std::is_signed_v) { auto mask = r != 0 && (r < 0 != b < 0); r = select(mask, r + b, r); } return r; } /* select: asd::bitselect takes x where the mask bits are set and y elsewhere. The mask is first converted to a type chosen by sizeof(T1); the exact types are not visible in this copy. */ template Simd select(Simd mask, Simd x, Simd y) { static_assert(std::is_same_v); if constexpr (sizeof(T1) == 1) { return asd::bitselect(y.value, x.value, asd::convert(mask.value)); } else if constexpr (sizeof(T1) == 2) { return asd::bitselect(y.value, x.value, asd::convert(mask.value)); } else if constexpr (sizeof(T1) == 4) { return asd::bitselect(y.value, x.value, asd::convert(mask.value)); } else { return asd::bitselect(y.value, x.value, asd::convert(mask.value)); } } /* pow: non-integral types use asd::pow. For integers, the whole result is 0 if any lane has a negative exponent. Otherwise square-and-multiply runs until every lane's exponent reaches 0. */ template Simd pow(Simd base, Simd exp) { if constexpr (!std::is_integral_v) { return asd::pow(base.value, exp.value); } else { Simd res = 1; // Raising an integer to a negative power is undefined if (any(exp < 0)) { return 0; } while (any(exp > 0)) { res = select((exp & 1) != 0, res * base, res); base = select(exp > 0, base * base, base); exp = exp >> 1; } return res; } } template Simd clamp(Simd v, Simd min, Simd max) { return asd::clamp(v.value, min.value, max.value); } template Simd fma(Simd x, Simd y, U z) { return asd::muladd(x.value, y.value, Simd(z).value); } // Reductions template bool all(Simd x) { return asd::all(x.value); } template bool any(Simd x) { return asd::any(x.value); } template T sum(Simd x) { return asd::reduce_add(x.value); } template T max(Simd x) { return asd::reduce_max(x.value); } template T min(Simd x) { return asd::reduce_min(x.value); } /* prod: no Accelerate reduction is used here; the two halves of the vector are multiplied and the product recurses on the half-width vector. */ template T prod(Simd x) { auto ptr = (T*)&x; auto lhs = load(ptr); auto rhs = load(ptr + N / 2); return prod(lhs * rhs); } } // namespace mlx::core::simd #if 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC #include "mlx/backend/cpu/simd/accelerate_fp16_simd.h" #endif ================================================ FILE: mlx/backend/cpu/simd/base_simd.h ================================================ #pragma once // Required for using M_LN2 in MSVC. #define _USE_MATH_DEFINES #include #include #include #include #include #ifdef _MSC_VER #include // For _BitScanReverse #endif namespace mlx::core::simd { template struct Simd; template static constexpr int max_size = 1; /* Width-1 specialization: a scalar wrapper used as the portable fallback. */ template struct Simd { static constexpr int size = 1; T value; Simd() {} template Simd(Simd v) : value(v.value) {} template Simd(U v) : value(v) {} T operator[](int) const { return value; } T& operator[](int) { return value; } }; /* load/store reinterpret raw memory as a Simd; store first masks bool vectors down to 0/1. */ template Simd load(const T* x) { return *(Simd*)x; } template void store(T* dst, Simd x) { // Maintain invariant that bool is either 0 or 1 as // simd comparison ops set all bits in the result to 1 if constexpr (std::is_same_v && N > 1) { x = x & 1; } *(Simd*)dst = x; } template constexpr bool is_complex = false; template constexpr bool is_complex().real())>> = true; template Simd rint(Simd in) { if constexpr (is_complex) { return Simd{ T{std::rint(in.value.real()), std::rint(in.value.imag())}}; } else { return Simd{std::rint(in.value)}; } } template Simd rsqrt(Simd in) { return T(1.0) / sqrt(in); } template Simd recip(Simd in) { return T(1.0) / in; } #define DEFAULT_UNARY(name, op) \ template \ Simd name(Simd in) { \ return op(in.value); \ } DEFAULT_UNARY(operator-, std::negate{}) DEFAULT_UNARY(operator!, std::logical_not{}) DEFAULT_UNARY(abs, std::abs) DEFAULT_UNARY(acos, std::acos) DEFAULT_UNARY(acosh, std::acosh) DEFAULT_UNARY(asin, std::asin) DEFAULT_UNARY(asinh, std::asinh) DEFAULT_UNARY(atan, std::atan) DEFAULT_UNARY(atanh, std::atanh) DEFAULT_UNARY(ceil, std::ceil) DEFAULT_UNARY(conj, std::conj) DEFAULT_UNARY(cosh, std::cosh) DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log10, std::log10) DEFAULT_UNARY(sinh, std::sinh) DEFAULT_UNARY(sqrt, std::sqrt) DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tanh, std::tanh) /* log1p: for complex inputs the angle is atan2(y, x + 1). When |z| < 0.5 the modulus is taken as log1p(x*(2+x) + y*y) / 2 to keep precision near zero; otherwise it is log(hypot(x + 1, y)). */ template Simd log1p(Simd in) { if constexpr (is_complex) { auto x = in.value.real(); auto y = in.value.imag(); auto zabs = std::abs(in.value); auto theta = std::atan2(y, x + 1); if (zabs < 0.5) { auto r = x * (2 + x) + y * y; if (r == 0) { // handle underflow return Simd{T{x, theta}}; } return Simd{T{((decltype(x))(0.5)) * std::log1p(r), theta}}; } else { auto z0 = std::hypot(x + 1, y); return Simd{T{std::log(z0), theta}}; } } else { return Simd{std::log1p(in.value)}; } } /* log2: complex inputs divide the natural log by ln 2. */ template Simd log2(Simd in) { if constexpr (is_complex) { auto out = std::log(in.value); auto scale = decltype(out.real())(M_LN2); return Simd{T{out.real() / scale, out.imag() / scale}}; } else { return Simd{std::log2(in.value)}; } } template Simd operator~(Simd in) { return ~in.value; } template auto real(Simd in) -> Simd { return std::real(in.value); } template auto imag(Simd in) -> Simd { return std::imag(in.value); } template Simd isnan(Simd in) { return std::isnan(in.value); } #define DEFAULT_BINARY(OP) \ template \ auto operator OP(Simd a, Simd b) \ ->Simd { \ return a.value OP b.value; \ } \ template \ auto operator OP(T1 a, Simd b)->Simd { \ return a OP b.value; \ } \ template \ auto operator OP(Simd a, T2 b)->Simd { \ return a.value OP b; \ } DEFAULT_BINARY(+) DEFAULT_BINARY(-) DEFAULT_BINARY(*) DEFAULT_BINARY(/) DEFAULT_BINARY(<<) DEFAULT_BINARY(>>) DEFAULT_BINARY(|) DEFAULT_BINARY(^) DEFAULT_BINARY(&) DEFAULT_BINARY(&&) DEFAULT_BINARY(||) template Simd clz(Simd x_) { #ifdef _MSC_VER // MSVC doesn't have __builtin_clz, use _BitScanReverse instead unsigned long index; if (_BitScanReverse(&index, static_cast(x_.value))) { return static_cast(31 - index); } return static_cast(32); // All zeros case #else return __builtin_clz(x_.value); #endif } /* remainder: for signed types the result takes the sign of the divisor. */ template Simd remainder(Simd a_, Simd b_) { T a = a_.value; T b = b_.value; T r; if constexpr
(std::is_integral_v) { r = a % b; } else { r = std::remainder(a, b); } if constexpr (std::is_signed_v) { if (r != 0 && (r < 0 != b < 0)) { r += b; } } return r; } /* maximum/minimum: return a when a is NaN. When only b is NaN the comparison is false and b is returned, so NaN propagates either way. */ template Simd maximum(Simd a_, Simd b_) { T a = a_.value; T b = b_.value; if constexpr (!std::is_integral_v) { if (std::isnan(a)) { return a; } } return (a > b) ? a : b; } template Simd minimum(Simd a_, Simd b_) { T a = a_.value; T b = b_.value; if constexpr (!std::is_integral_v) { if (std::isnan(a)) { return a; } } return (a < b) ? a : b; }
// Scalar power. Floating-point types defer to std::pow; integral types use
// square-and-multiply over the bits of the exponent.
template <typename T>
Simd<T, 1> pow(Simd<T, 1> a, Simd<T, 1> b) {
  T base = a.value;
  T exp = b.value;
  if constexpr (!std::is_integral_v<T>) {
    return std::pow(base, exp);
  } else {
    if constexpr (std::is_signed_v<T>) {
      // A negative exponent never reaches zero under `exp >>= 1` (it sticks
      // at -1), so the loop below would never terminate. Return 0 instead,
      // matching the vectorized Accelerate pow, which treats a negative
      // integer power as undefined and returns 0.
      if (exp < 0) {
        return T(0);
      }
    }
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      // Square only while exponent bits remain: the trailing, unused square
      // could overflow a signed T (undefined behavior) even when res fits.
      if (exp) {
        base *= base;
      }
    }
    return res;
  }
}
template Simd atan2(Simd a, Simd b) { return std::atan2(a.value, b.value); } #define DEFAULT_COMPARISONS(OP) \ template \ Simd operator OP(Simd a, Simd b) { \ return a.value OP b.value; \ } \ template \ Simd operator OP(T1 a, Simd b) { \ return a OP b.value; \ } \ template \ Simd operator OP(Simd a, T2 b) { \ return a.value OP b; \ } DEFAULT_COMPARISONS(>) DEFAULT_COMPARISONS(<) DEFAULT_COMPARISONS(>=) DEFAULT_COMPARISONS(<=) DEFAULT_COMPARISONS(==) DEFAULT_COMPARISONS(!=) template Simd select(Simd mask, Simd x, Simd y) { return mask.value ? x.value : y.value; } template Simd clamp(Simd v, Simd min, Simd max) { return std::clamp(v.value, min.value, max.value); } template Simd fma(Simd x, Simd y, U z) { return std::fma(x.value, y.value, Simd(z).value); } // Reductions #define DEFAULT_REDUCTION(name, type) \ template \ type name(Simd x) { \ return x.value; \ } DEFAULT_REDUCTION(max, T) DEFAULT_REDUCTION(min, T) DEFAULT_REDUCTION(sum, T) DEFAULT_REDUCTION(prod, T) DEFAULT_REDUCTION(any, bool) DEFAULT_REDUCTION(all, bool) } // namespace mlx::core::simd ================================================ FILE: mlx/backend/cpu/simd/math.h ================================================ // Copyright © 2024 Apple Inc.
#pragma once #include "mlx/backend/cpu/simd/type.h" namespace mlx::core::simd { constexpr float inf = std::numeric_limits::infinity(); /** * Compute exp(x) in an optimizer friendly way as follows: * * First change the problem to computing 2**y where y = x / ln(2). * * Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part * `ipart` and y2 is fractional part. For the integer part we perform bit * shifting and for the fractional part we use a polynomial approximation. * * The algorithm and constants of the polynomial taken from * https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them * from Cephes math library. * * Note: The implementation below is a general fast exp. There could be faster * implementations for numbers strictly < 0. */ /* Inputs above 88 map to +inf and inputs below -88 map to 0; NaN inputs pass through unchanged. Complex inputs use std::exp. */ template Simd exp(Simd in) { if constexpr (is_complex) { return Simd{std::exp(in.value)}; } else { Simd x_init = in; auto x = x_init * 1.442695f; // multiply with log_2(e) Simd ipart, fpart; ipart = floor(x + 0.5); fpart = x - ipart; x = 1.535336188319500e-4f; x = fma(x, fpart, 1.339887440266574e-3f); x = fma(x, fpart, 9.618437357674640e-3f); x = fma(x, fpart, 5.550332471162809e-2f); x = fma(x, fpart, 2.402264791363012e-1f); x = fma(x, fpart, 6.931472028550421e-1f); x = fma(x, fpart, 1.000000000000000f); // generate 2**ipart in the floating point representation using integer // bitshifting Simd epart = (Simd(ipart) + 127) << 23; // Deal with NaN and Inf auto result = select(isnan(x_init), x_init, (*(Simd*)&epart) * x); result = select(x_init > 88.0f, Simd(inf), result); result = select(x_init < -88.0f, Simd(0), result); return Simd(result); } } /* Implementation from: * https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357 * which originally came from the Cephes math library.
*/ /* sincos: works on |x| with range reduction by Pi/4; the compile-time flag Sine selects the sine or the cosine result. NOTE(review): poly_mask is used at the end of this function, but its definition is not visible in this copy. The text between a '<' in the comment below and a later '>' appears to have been lost; verify against the original file. */ template Simd sincos(Simd in) { auto sign_mask_sin = in < 0; in = abs(in); Simd x = in; // scale by 4/Pi auto y = x * 1.27323954473516f; // store the integer part of y in mm0 Simd emm2 = y; // j=(j+1) & (~1) (see the cephes sources) emm2 = emm2 + 1; emm2 = emm2 & ~1; y = emm2; // Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4 // and another one for Pi/4(-0.78515625f), x); x = fma(y, Simd(-2.4187564849853515625e-4f), x); x = fma(y, Simd(-3.77489497744594108e-8f), x); sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0); auto sign_mask_cos = ((emm2 - 2) & 4) != 0; // Evaluate the first polynom (0 <= x <= Pi/4) in y1, // and the second polynom (Pi/4 <= x <= 0) in y2 auto z = x * x; auto y1 = fma(z, Simd(2.443315711809948e-5f), -1.388731625493765e-3f); auto y2 = fma(z, Simd(-1.9515295891e-4f), 8.3321608736e-3f); y1 = fma(y1, z, 4.166664568298827e-2f); y2 = fma(y2, z, -1.6666654611e-1f); y1 = y1 * z; y2 = y2 * z; y1 = y1 * z; y2 = fma(x, y2, x); y1 = fma(z, Simd(-0.5f), y1); y1 = y1 + 1.0f; if constexpr (Sine) { auto ys = select(poly_mask, y1, y2); return select(sign_mask_sin, -ys, ys); } else { auto yc = select(poly_mask, y2, y1); return select(sign_mask_cos, yc, -yc); } } /* sin/cos: complex inputs defer to std::sin / std::cos, everything else to sincos. */ template Simd sin(Simd x) { if constexpr (is_complex) { return std::sin(x.value); } else { return sincos(x); } } template Simd cos(Simd x) { if constexpr (is_complex) { return std::cos(x.value); } else { return sincos(x); } } /* erf: polynomial approximation taken from the PyTorch source linked below; the sign of the result is restored with select on x > 0. */ template Simd erf(Simd x) { // https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175 Simd v = x; auto t = recip(fma(Simd(0.3275911f), abs(v), 1.0f)); auto r = fma(Simd(1.061405429f), t, -1.453152027f); r = fma(r, t, 1.421413741f); r = fma(r, t, -0.284496736f); r = fma(r, t, 0.254829592f); auto e = -exp(-v * v); auto result = Simd(fma(e * t, r, 1.0f)); return select(x > 0, result, -result); } /* erfinv: computes t = log(1 - a*a) and multiplies a by one of two polynomials in t, chosen by whether |t| > 6.125. For N > 1 both polynomials are evaluated and blended with select. */ template Simd erfinv(Simd a_) { Simd a = a_; auto t = fma(a, 0.0f - a, 1.0f); t = log(t);
auto lhs = [](auto t) { Simd p; p = 3.03697567e-10f; // 0x1.4deb44p-32 p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 }; auto rhs = [](auto t) { Simd p; p = 5.43877832e-9f; // 0x1.75c000p-28 p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 }; auto thresh = 6.125f; // Compute both branches and select if N > 1 if constexpr (N == 1) { if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793 return a * lhs(t); } else { // maximum ulp error = 2.35002 return a * rhs(t); } } else { return a * select(abs(t) > thresh, lhs(t), rhs(t)); } } } // namespace mlx::core::simd ================================================ FILE: mlx/backend/cpu/simd/neon_fp16_simd.h ================================================ #pragma once #include #include "mlx/backend/cpu/simd/base_simd.h" namespace mlx::core::simd { constexpr int N = 8; /* Eight-lane float16 vector backed by a NEON float16x8_t. Conversions to and from other Simd types are implemented with NEON convert intrinsics. */ template <> struct Simd { static constexpr int size = N; using scalar_t = float16_t; Simd() {} template Simd(U v) : value(vdupq_n_f16(v)){}; Simd(float16x8_t v) : value(v){}; Simd(Simd other) { auto f32x4_a = *(float32x4_t*)(&other); auto f32x4_b = *((float32x4_t*)(&other) + 1); value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b); }; Simd(Simd other) { value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value)); }; operator
Simd() { auto v = vcvtq_s16_f16(value); return load((int16_t*)&v); }; operator Simd() { float32x4x2_t v; v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value)); v.val[1] = vcvt_high_f32_f16(value); return load((float*)&v); } float16_t operator[](int idx) const { return reinterpret_cast(&value)[idx]; } float16_t& operator[](int idx) { return reinterpret_cast(&value)[idx]; } float16x8_t value; }; /* NOTE(review): rsqrt and recip below map to vrsqrteq_f16 / vrecpeq_f16, which are the NEON reciprocal estimate instructions. No refinement step follows them here, so confirm the reduced precision is acceptable. */ #define DEFINE_NEON_UNARY_OP(name, op) \ inline Simd name(Simd a) { \ return Simd{op(a.value)}; \ } DEFINE_NEON_UNARY_OP(abs, vabsq_f16) DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16) DEFINE_NEON_UNARY_OP(floor, vrndmq_f16) DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16) DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16) DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16) DEFINE_NEON_UNARY_OP(rint, vrndnq_f16) #define DEFINE_NEON_BINARY_OP(name, op) \ inline Simd name(Simd a, Simd b) { \ return op(a.value, b.value); \ } \ template \ Simd name(Simd a, T b) { \ return op(a.value, Simd(b).value); \ } \ template \ Simd name(T a, Simd b) { \ return op(Simd(a).value, b.value); \ } /* NOTE(review): here and in DEFINE_NEON_COMPARISON, `*(uint16_t*)&out` reads only the first 16-bit lane of the NEON mask. If the target Simd constructor broadcasts a scalar, every lane would receive lane 0's result. A full-vector reinterpretation looks intended; confirm. */ inline Simd operator!(Simd v) { auto out = vceqzq_f16(v.value); return Simd(*(uint16_t*)&out); } inline Simd operator-(Simd v) { return vnegq_f16(v.value); } DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16) DEFINE_NEON_BINARY_OP(minimum, vminq_f16) DEFINE_NEON_BINARY_OP(operator+, vaddq_f16) DEFINE_NEON_BINARY_OP(operator-, vsubq_f16) DEFINE_NEON_BINARY_OP(operator*, vmulq_f16) DEFINE_NEON_BINARY_OP(operator/, vdivq_f16) #define DEFINE_NEON_COMPARISON(Op, op) \ template \ Simd operator Op(Simd a, T b) { \ auto out = op(a.value, Simd(b).value); \ return Simd(*(uint16_t*)(&out)); \ } \ template \ Simd operator Op(T a, Simd b) { \ auto out = op(Simd(a).value, b.value); \ return Simd(*(uint16_t*)(&out)); \ } \ inline Simd operator Op( \ Simd a, Simd b) { \ auto out = op(a.value, b.value); \ return Simd(*(uint16_t*)(&out)); \ } DEFINE_NEON_COMPARISON(==, vceqq_f16) DEFINE_NEON_COMPARISON(>=, vcgeq_f16) DEFINE_NEON_COMPARISON(<=, vcleq_f16)
// NOTE(review): this text is a flattened repository dump. Each physical line
// packs many source lines together, and angle-bracketed spans (template
// parameter lists, template arguments, <...> #include targets) appear to have
// been stripped, so it is not compilable as-is.
// Contents of the two lines below:
// - End of mlx/backend/cpu/simd/neon_fp16_simd.h:
//   * the remaining comparison operators and != (written as !(a == b));
//   * logical || and && (via comparisons against 0);
//   * isnan (v != v) and clamp (minimum(maximum(v, min), max));
//   * fma and select (vbslq_f16 bit-select: x where the mask is set, else y);
//   * horizontal reductions: max / min / sum use three pairwise
//     vpmax / vpmin / vpadd steps on the 4-lane halves, and prod multiplies
//     the two halves and then the four remaining lanes.
// - fma(x, y, z) must return x * y + z, the operand order the Horner steps in
//   simd/math.h rely on (e.g. fma(p, t, c) in erfinv). vfmaq_f16(a, b, c)
//   computes a + b * c, so the addend z is passed as the first argument.
// - mlx/backend/cpu/simd/simd.h: includes math.h and type.h.
// - mlx/backend/cpu/simd/type.h: includes accelerate_simd.h when
//   MLX_USE_ACCELERATE is defined and the target is not x86_64.
// - mlx/backend/cpu/slicing.h: declarations of prepare_slice and
//   shared_buffer_slice.
DEFINE_NEON_COMPARISON(>, vcgtq_f16) DEFINE_NEON_COMPARISON(<, vcltq_f16) template Simd operator!=(Simd a, T b) { return !(a == b); } template Simd operator!=(T a, Simd b) { return !(a == b); } inline Simd operator!=(Simd a, Simd b) { return !(a == b); } inline Simd operator||( Simd a, Simd b) { return Simd((a != 0) || (b != 0)); } template Simd operator||(Simd a, T b) { return Simd((a != 0) || (b != 0)); } template Simd operator||(T a, Simd b) { return Simd((a != 0) || (b != 0)); } inline Simd operator&&( Simd a, Simd b) { return Simd((a != 0) && (b != 0)); } template Simd operator&&(Simd a, T b) { return Simd((a != 0) && (b != 0)); } template Simd operator&&(T a, Simd b) { return Simd((a != 0) && (b != 0)); } template <> inline Simd isnan(Simd v) { return v != v; } template <> inline Simd clamp(Simd v, Simd min, Simd max) { return minimum(maximum(v, min), max); } template Simd fma(Simd x, Simd y, T z) { return vfmaq_f16(Simd(z).value, x.value, y.value); } template Simd select(Simd mask, Simd x, Simd y) { return vbslq_f16(Simd(mask).value, x.value, y.value); } // Reductions inline float16_t max(Simd x) { float16x4_t y; y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value)); y = vpmax_f16(y, y); y = vpmax_f16(y, y); return vget_lane_f16(y, 0); } inline float16_t min(Simd x) { float16x4_t y; y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value)); y = vpmin_f16(y, y); y = vpmin_f16(y, y); return vget_lane_f16(y, 0); } inline float16_t sum(Simd x) { float16x4_t y; y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value)); y = vpadd_f16(y, y); y = vpadd_f16(y, y); return vget_lane_f16(y, 0); } inline float16_t prod(Simd x) { auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value)); hx[0] *= hx[1]; hx[0] *= hx[2]; hx[0] *= hx[3]; return hx[0]; } } // namespace mlx::core::simd ================================================ FILE: mlx/backend/cpu/simd/simd.h ================================================ #pragma once #include
"mlx/backend/cpu/simd/math.h" #include "mlx/backend/cpu/simd/type.h" ================================================ FILE: mlx/backend/cpu/simd/type.h ================================================ #pragma once #include "mlx/backend/cpu/simd/base_simd.h" #ifdef MLX_USE_ACCELERATE #if defined(__x86_64__) // the accelerate_simd implementation requires NEON -- use base implementation #else #include "mlx/backend/cpu/simd/accelerate_simd.h" #endif #endif ================================================ FILE: mlx/backend/cpu/slicing.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include "mlx/array.h" namespace mlx::core { std::tuple prepare_slice( const array& in, const Shape& start_indices, const Shape& strides); void shared_buffer_slice( const array& in, const Strides& out_strides, size_t data_offset, size_t data_size, array& out); } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/softmax.cpp ================================================ // Copyright © 2023-2024 Apple Inc.
// NOTE(review): this text is a flattened repository dump. Each physical line
// packs many source lines together, and angle-bracketed spans (template
// parameter lists, template arguments, <...> #include targets) appear to have
// been stripped, so it is not compilable as-is.
// Contents of the two lines below (mlx/backend/cpu/softmax.cpp):
// - softmax(in, out, stream): processes L = data_size / M rows of length
//   M = shape().back() in three passes per row.
//   1. Row maximum: N-wide vector loads, then a scalar tail.
//   2. Normalizer: the sum of exp(x - max). When T and AccT are the same type
//      (same_t), the exponentials are also stored into `out`.
//   3. Normalize by 1 / normalizer. When same_t, the stored exponentials are
//      rescaled; otherwise exp(x - max) is recomputed from the input and cast
//      back to T.
//   The vector width N is the std::min of two max_size values (template
//   arguments stripped; presumably max_size for AccT and for T -- confirm).
// - Softmax::eval_cpu:
//   * If the input is contiguous with a last-axis stride of 1, the output
//     either shares the input buffer (when donatable) or gets a fresh
//     allocation with the input's data_size / strides / flags.
//   * Otherwise a contiguous copy is made and shared with the output.
//   * Dispatches on dtype and throws for non-floating types.
//   * The float16 / bfloat16 `precise_` branches look identical here only
//     because their template arguments were stripped (presumably a float32
//     accumulator when precise_ -- confirm against the repository).
// NOTE(review): `int L = in.data_size() / M` divides by zero when the last
// axis has length 0; confirm empty inputs never reach this path.
#include #include #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" #include "mlx/primitives.h" #include "mlx/types/limits.h" namespace mlx::core { namespace { using namespace mlx::core::simd; template void softmax(const array& in, array& out, Stream stream) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(in); encoder.set_output_array(out); const T* in_ptr = in.data(); T* out_ptr = out.data(); int M = in.shape().back(); int L = in.data_size() / M; encoder.dispatch([in_ptr, out_ptr, M, L]() mutable { constexpr bool same_t = std::is_same_v; constexpr int N = std::min(max_size, max_size); const T* current_in_ptr; T* current_out_ptr; for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) { // Find the maximum current_in_ptr = in_ptr; Simd vmaximum(-numeric_limits::infinity()); size_t s = M; while (s >= N) { Simd vals = load(current_in_ptr); vmaximum = maximum(vals, vmaximum); current_in_ptr += N; s -= N; } AccT maximum = max(vmaximum); while (s-- > 0) { maximum = std::max(maximum, static_cast(*current_in_ptr)); current_in_ptr++; } // Compute the normalizer and the exponentials Simd vnormalizer(0.0); current_out_ptr = out_ptr; current_in_ptr = in_ptr; s = M; while (s >= N) { Simd vexp = load(current_in_ptr); vexp = exp(vexp - maximum); if constexpr (same_t) { store(current_out_ptr, vexp); } vnormalizer = vnormalizer + vexp; current_in_ptr += N; current_out_ptr += N; s -= N; } AccT normalizer = sum(vnormalizer); while (s-- > 0) { AccT _exp = std::exp(*current_in_ptr - maximum); if constexpr (same_t) { *current_out_ptr = _exp; } normalizer += _exp; current_in_ptr++; current_out_ptr++; } normalizer = 1 / normalizer; // Normalize current_out_ptr = out_ptr; current_in_ptr = in_ptr; s = M; while (s >= N) { if constexpr (same_t) { store( current_out_ptr, Simd(load(current_out_ptr) * normalizer)); } else { Simd vexp = load(current_in_ptr); vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd(vexp)); current_in_ptr += N; } current_out_ptr += N; s -= N; } while (s-- > 0) { if constexpr (same_t) { *current_out_ptr *= normalizer; } else { AccT _exp = std::exp(*current_in_ptr - maximum); *current_out_ptr = static_cast(_exp * normalizer); current_in_ptr++; } current_out_ptr++; } } }); } } // namespace void Softmax::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); // Make sure that the last dimension is contiguous auto set_output = [s = stream(), &out](const array& x) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.is_donatable()) { out.copy_shared_buffer(x); } else { out.set_data( allocator::malloc(x.data_size() * x.itemsize()), x.data_size(), x.strides(), x.flags()); } return x; } else { array x_copy = contiguous_copy_cpu(x, s); out.copy_shared_buffer(x_copy); return x_copy; } }; auto in = set_output(inputs[0]); switch (in.dtype()) { case float32: softmax(in, out, stream()); break; case float16: if (precise_) { softmax(in, out, stream()); } else { softmax(in, out, stream()); } break; case bfloat16: if (precise_) { softmax(in, out, stream()); } else { softmax(in, out, stream()); } break; case float64: softmax(in, out, stream()); break; default: throw std::runtime_error( "[softmax] Only defined for floating point types."); break; } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/sort.cpp ================================================ // Copyright © 2023 Apple Inc.
// NOTE(review): this text is a flattened repository dump. Each physical line
// packs many source lines together, and angle-bracketed spans (template
// parameter lists, template arguments, <...> #include targets) appear to have
// been stripped, so it is not compilable as-is.
// Contents of the lines below:
// - mlx/backend/cpu/sort.cpp:
//   * nan_aware_less orders NaNs after every other value.
//   * StridedIterator is a random-access iterator that steps a raw pointer
//     by a fixed element stride. Its operator!= is the negation of
//     operator==, so the two comparisons can never disagree.
//   * sort / argsort apply std::stable_sort to each row along `axis`;
//     partition / argpartition apply std::nth_element.
//   * In the ArgSort / Sort / ArgPartition / Partition::eval_cpu dtype
//     dispatch, Sort and Partition copy the input into `out` and work in
//     place.
//   * Every variant registers `out` with set_output_array.
//   * Partition normalizes a negative axis before indexing strides, as Sort
//     does.
// - mlx/backend/cpu/svd.cpp:
//   * SVDWork wraps LAPACK gesdd: a workspace query with lwork = -1, then
//     run().
//   * svd_impl swaps M and N so the row-major input can be passed to
//     column-major LAPACK (see the identity comment in svd_impl).
//   * For complex64 with jobz == 'A', rwork is sized per the cgesdd
//     requirement max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), where
//     mn = min(M, N) = K and mx = max(M, N).
//   * The per-matrix loop index is size_t, so the `* i` offset products are
//     evaluated in size_t.
// - mlx/backend/cpu/ternary.h:
//   * ternary_op handles scalar/scalar/scalar and vector/vector/vector
//     layouts directly.
//   * Other layouts go through collapse_contiguous_dims and
//     ternary_op_dispatch_dims, which recurses via ternary_op_dims over the
//     trailing (up to two) dims and walks leading dims with
//     ContiguousIterator.
#include #include #include #include #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" namespace mlx::core { namespace { template inline constexpr bool is_floating_v = std::is_floating_point_v || std::is_same_v || std::is_same_v; // NaN-aware comparator that places NaNs at the end template bool nan_aware_less(T a, T b) { if constexpr (is_floating_v || std::is_same_v) { if (std::isnan(a)) return false; if (std::isnan(b)) return true; } return a < b; } template struct StridedIterator { using iterator_category = std::random_access_iterator_tag; using difference_type = int32_t; using value_type = T; using reference = value_type&; using pointer = value_type*; // Constructors StridedIterator() = default; explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0) : stride_(stride), ptr_(ptr + offset * stride) {} explicit StridedIterator(array& arr, int axis, difference_type offset = 0) : StridedIterator(arr.data(), arr.strides()[axis], offset) {} // Accessors reference operator*() const { return ptr_[0]; } reference operator[](difference_type idx) const { return ptr_[idx * stride_]; } // Comparisons bool operator==(const StridedIterator& other) const { return ptr_ == other.ptr_ && stride_ == other.stride_; } bool operator!=(const StridedIterator& other) const { return !(*this == other); } bool operator<(const StridedIterator& other) const { return ptr_ < other.ptr_; } bool operator>(const StridedIterator& other) const { return ptr_ > other.ptr_; } bool operator<=(const StridedIterator& other) const { return ptr_ <= other.ptr_; } bool operator>=(const StridedIterator& other) const { return ptr_ >= other.ptr_; } difference_type operator-(const StridedIterator& other) const { return (ptr_ - other.ptr_) / stride_; } // Moving StridedIterator& operator++() { ptr_ += stride_; return *this; } StridedIterator& operator--() { ptr_ -= stride_; return
*this; } StridedIterator& operator+=(difference_type diff) { ptr_ += diff * stride_; return *this; } StridedIterator& operator-=(difference_type diff) { ptr_ -= diff * stride_; return *this; } StridedIterator operator+(difference_type diff) { return StridedIterator(ptr_, stride_, diff); } StridedIterator operator-(difference_type diff) { return StridedIterator(ptr_, stride_, -diff); } private: int64_t stride_; T* ptr_; }; template void sort(array& out, int axis) { // Get axis, shape and stride info axis = axis < 0 ? axis + out.ndim() : axis; size_t in_size = out.size(); size_t n_rows = in_size / out.shape(axis); auto remaining_shape = out.shape(); remaining_shape.erase(remaining_shape.begin() + axis); auto remaining_strides = out.strides(); remaining_strides.erase(remaining_strides.begin() + axis); auto axis_stride = out.strides()[axis]; auto axis_size = out.shape(axis); // Perform sorting in place ContiguousIterator src_it( remaining_shape, remaining_strides, remaining_shape.size()); auto out_ptr = out.data(); for (int i = 0; i < n_rows; i++) { T* data_ptr = out_ptr + src_it.loc; StridedIterator st(data_ptr, axis_stride, 0); StridedIterator ed(data_ptr, axis_stride, axis_size); std::stable_sort(st, ed, nan_aware_less); src_it.step(); } } template void argsort(const array& in, array& out, int axis) { // Get axis, shape and stride info axis = axis < 0 ?
axis + in.ndim() : axis; size_t n_rows = in.size() / in.shape(axis); auto in_remaining_shape = in.shape(); in_remaining_shape.erase(in_remaining_shape.begin() + axis); auto in_remaining_strides = in.strides(); in_remaining_strides.erase(in_remaining_strides.begin() + axis); auto out_remaining_shape = out.shape(); out_remaining_shape.erase(out_remaining_shape.begin() + axis); auto out_remaining_strides = out.strides(); out_remaining_strides.erase(out_remaining_strides.begin() + axis); auto in_stride = in.strides()[axis]; auto out_stride = out.strides()[axis]; auto axis_size = in.shape(axis); // Perform sorting ContiguousIterator in_it( in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); ContiguousIterator out_it( out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); auto in_ptr = in.data(); auto out_ptr = out.data(); for (int i = 0; i < n_rows; i++) { const T* data_ptr = in_ptr + in_it.loc; IdxT* idx_ptr = out_ptr + out_it.loc; in_it.step(); out_it.step(); StridedIterator st_(idx_ptr, out_stride, 0); StridedIterator ed_(idx_ptr, out_stride, axis_size); // Initialize with iota std::iota(st_, ed_, IdxT(0)); // Sort according to vals StridedIterator st(idx_ptr, out_stride, 0); StridedIterator ed(idx_ptr, out_stride, axis_size); std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { auto v1 = data_ptr[a * in_stride]; auto v2 = data_ptr[b * in_stride]; // Handle NaNs (place them at the end) if constexpr (is_floating_v) { if (std::isnan(v1)) return false; if (std::isnan(v2)) return true; } return v1 < v2 || (v1 == v2 && a < b); }); } } template void partition(array& out, int axis, int kth) { // Get axis, shape and stride info axis = axis < 0 ?
axis + out.ndim() : axis; size_t in_size = out.size(); size_t n_rows = in_size / out.shape(axis); auto remaining_shape = out.shape(); remaining_shape.erase(remaining_shape.begin() + axis); auto remaining_strides = out.strides(); remaining_strides.erase(remaining_strides.begin() + axis); auto axis_stride = out.strides()[axis]; int axis_size = out.shape(axis); kth = kth < 0 ? kth + axis_size : kth; // Perform partition in place ContiguousIterator src_it( remaining_shape, remaining_strides, remaining_shape.size()); auto out_ptr = out.data(); for (int i = 0; i < n_rows; i++) { T* data_ptr = out_ptr + src_it.loc; src_it.step(); StridedIterator st(data_ptr, axis_stride, 0); StridedIterator md(data_ptr, axis_stride, kth); StridedIterator ed(data_ptr, axis_stride, axis_size); std::nth_element(st, md, ed, nan_aware_less); } } template void argpartition(const array& in, array& out, int axis, int kth) { // Get axis, shape and stride info axis = axis < 0 ? axis + in.ndim() : axis; size_t n_rows = in.size() / in.shape(axis); auto in_remaining_shape = in.shape(); in_remaining_shape.erase(in_remaining_shape.begin() + axis); auto in_remaining_strides = in.strides(); in_remaining_strides.erase(in_remaining_strides.begin() + axis); auto out_remaining_shape = out.shape(); out_remaining_shape.erase(out_remaining_shape.begin() + axis); auto out_remaining_strides = out.strides(); out_remaining_strides.erase(out_remaining_strides.begin() + axis); auto in_stride = in.strides()[axis]; auto out_stride = out.strides()[axis]; auto axis_size = in.shape(axis); kth = kth < 0 ?
kth + axis_size : kth; // Perform partition ContiguousIterator in_it( in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); ContiguousIterator out_it( out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); auto in_ptr = in.data(); auto out_ptr = out.data(); for (int i = 0; i < n_rows; i++) { const T* data_ptr = in_ptr + in_it.loc; IdxT* idx_ptr = out_ptr + out_it.loc; in_it.step(); out_it.step(); StridedIterator st_(idx_ptr, out_stride, 0); StridedIterator ed_(idx_ptr, out_stride, axis_size); // Initialize with iota std::iota(st_, ed_, IdxT(0)); // Partition according to vals StridedIterator st(idx_ptr, out_stride, 0); StridedIterator md(idx_ptr, out_stride, kth); StridedIterator ed(idx_ptr, out_stride, axis_size); std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { auto v1 = data_ptr[a * in_stride]; auto v2 = data_ptr[b * in_stride]; // Handle NaNs (place them at the end) if constexpr (is_floating_v) { if (std::isnan(v1)) return false; if (std::isnan(v2)) return true; } return v1 < v2 || (v1 == v2 && a < b); }); } } } // namespace void ArgSort::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; // Allocate output out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); encoder.set_output_array(out); encoder.dispatch([in = array::unsafe_weak_copy(in), out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable { switch (in.dtype()) { case bool_: return argsort(in, out, axis_); case uint8: return argsort(in, out, axis_); case uint16: return argsort(in, out, axis_); case uint32: return argsort(in, out, axis_); case uint64: return argsort(in, out, axis_); case int8: return argsort(in, out, axis_); case int16: return argsort(in, out, axis_); case int32: return argsort(in, out, axis_); case int64: return argsort(in, out, axis_); case float32: return argsort(in, out, axis_); case float64: return argsort(in,
out, axis_); case float16: return argsort(in, out, axis_); case bfloat16: return argsort(in, out, axis_); case complex64: return argsort(in, out, axis_); } }); } void Sort::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; int axis = axis_; if (axis < 0) { axis += in.ndim(); } // Copy input to output CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0) ? CopyType::Vector : CopyType::General; copy_cpu(in, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_output_array(out); encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable { dispatch_all_types(out.dtype(), [&](auto type_tag) { sort(out, axis); }); }); } void ArgPartition::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; // Allocate output out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); encoder.set_output_array(out); encoder.dispatch([in = array::unsafe_weak_copy(in), out = array::unsafe_weak_copy(out), axis_ = axis_, kth_ = kth_]() mutable { switch (in.dtype()) { case bool_: return argpartition(in, out, axis_, kth_); case uint8: return argpartition(in, out, axis_, kth_); case uint16: return argpartition(in, out, axis_, kth_); case uint32: return argpartition(in, out, axis_, kth_); case uint64: return argpartition(in, out, axis_, kth_); case int8: return argpartition(in, out, axis_, kth_); case int16: return argpartition(in, out, axis_, kth_); case int32: return argpartition(in, out, axis_, kth_); case int64: return argpartition(in, out, axis_, kth_); case float32: return argpartition(in, out, axis_, kth_); case float64: return argpartition(in, out, axis_, kth_); case float16: return argpartition(in, out, axis_, kth_); case bfloat16: return argpartition(in, out, axis_, kth_); case complex64: return argpartition(in, out, axis_, kth_); } }); } void
Partition::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; int axis = axis_; if (axis < 0) { axis += in.ndim(); } // Copy input to output CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0) ? CopyType::Vector : CopyType::General; copy_cpu(in, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_output_array(out); encoder.dispatch([out = array::unsafe_weak_copy(out), axis_ = axis_, kth_ = kth_]() mutable { switch (out.dtype()) { case bool_: return partition(out, axis_, kth_); case uint8: return partition(out, axis_, kth_); case uint16: return partition(out, axis_, kth_); case uint32: return partition(out, axis_, kth_); case uint64: return partition(out, axis_, kth_); case int8: return partition(out, axis_, kth_); case int16: return partition(out, axis_, kth_); case int32: return partition(out, axis_, kth_); case int64: return partition(out, axis_, kth_); case float32: return partition(out, axis_, kth_); case float64: return partition(out, axis_, kth_); case float16: return partition(out, axis_, kth_); case bfloat16: return partition(out, axis_, kth_); case complex64: return partition(out, axis_, kth_); } }); } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/svd.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/allocator.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" namespace mlx::core { template struct SVDWork {}; template struct SVDWork< T, typename std::enable_if::value>::type> { using R = T; int N; int M; int K; int lda; int ldu; int ldvt; char jobz; std::vector buffers; int lwork; SVDWork(int N, int M, int K, char jobz) : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) { T workspace_dimension = 0; // Integer workspace (iwork) of 8 * min(M, N) = 8 * K entries; required by // gesdd but not read afterwards.
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K)); int lwork_query = -1; int info; // Compute workspace size. gesdd( /* jobz = */ &jobz, // M and N are swapped since lapack expects column-major. /* m = */ &N, /* n = */ &M, /* a = */ nullptr, /* lda = */ &lda, /* s = */ nullptr, /* u = */ nullptr, /* ldu = */ &ldu, /* vt = */ nullptr, /* ldvt = */ &ldvt, /* work = */ &workspace_dimension, /* lwork = */ &lwork_query, /* iwork = */ static_cast(buffers[0].buffer.raw_ptr()), /* info = */ &info); if (info != 0) { std::stringstream ss; ss << "[SVD::eval_cpu] workspace calculation failed with code " << info; throw std::runtime_error(ss.str()); } lwork = workspace_dimension; buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); } void run(T* a, R* s, T* u, T* vt) { int info; gesdd( /* jobz = */ &jobz, // M and N are swapped since lapack expects column-major. /* m = */ &N, /* n = */ &M, /* a = */ a, /* lda = */ &lda, /* s = */ s, // According to the identity above, lapack will write Vᵀᵀ as U. /* u = */ u, /* ldu = */ &ldu, // According to the identity above, lapack will write Uᵀ as Vᵀ. /* vt = */ vt, /* ldvt = */ &ldvt, /* work = */ static_cast(buffers[1].buffer.raw_ptr()), /* lwork = */ &lwork, /* iwork = */ static_cast(buffers[0].buffer.raw_ptr()), /* info = */ &info); if (info != 0) { std::stringstream ss; ss << "[SVD::eval_cpu] gesdd failed with code " << info; throw std::runtime_error(ss.str()); } } }; template <> struct SVDWork> { using T = std::complex; using R = float; int N; int M; int K; int lda; int ldu; int ldvt; char jobz; std::vector buffers; int lwork; SVDWork(int N, int M, int K, char jobz) : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) { T workspace_dimension = 0; // Integer workspace (iwork) of 8 * min(M, N) = 8 * K entries; required by // gesdd but not read afterwards. buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K)); const int lrwork = jobz == 'A' ?
std::max(1, std::max(5 * K * K + 5 * K, 2 * std::max(M, N) * K + 2 * K * K + K)) : std::max(1, 7 * K); buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork)); int lwork_query = -1; int info; // Compute workspace size. gesdd( /* jobz = */ &jobz, // M and N are swapped since lapack expects column-major. /* m = */ &N, /* n = */ &M, /* a = */ nullptr, /* lda = */ &lda, /* s = */ nullptr, /* u = */ nullptr, /* ldu = */ &ldu, /* vt = */ nullptr, /* ldvt = */ &ldvt, /* work = */ &workspace_dimension, /* lwork = */ &lwork_query, /* rwork = */ static_cast(buffers[1].buffer.raw_ptr()), /* iwork = */ static_cast(buffers[0].buffer.raw_ptr()), /* info = */ &info); if (info != 0) { std::stringstream ss; ss << "[SVD::eval_cpu] workspace calculation failed with code " << info; throw std::runtime_error(ss.str()); } lwork = workspace_dimension.real(); buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); } void run(T* a, R* s, T* u, T* vt) { int info; gesdd( /* jobz = */ &jobz, // M and N are swapped since lapack expects column-major. /* m = */ &N, /* n = */ &M, /* a = */ a, /* lda = */ &lda, /* s = */ s, // According to the identity above, lapack will write Vᵀᵀ as U. /* u = */ u, /* ldu = */ &ldu, // According to the identity above, lapack will write Uᵀ as Vᵀ. /* vt = */ vt, /* ldvt = */ &ldvt, /* work = */ static_cast(buffers[2].buffer.raw_ptr()), /* lwork = */ &lwork, /* rwork = */ static_cast(buffers[1].buffer.raw_ptr()), /* iwork = */ static_cast(buffers[0].buffer.raw_ptr()), /* info = */ &info); if (info != 0) { std::stringstream ss; ss << "[SVD::eval_cpu] gesdd failed with code " << info; throw std::runtime_error(ss.str()); } } }; template void svd_impl( const array& a, std::vector& outputs, bool compute_uv, Stream stream) { // Lapack uses the column-major convention.
To avoid having to transpose // the input and then transpose the outputs, we swap the indices/sizes of the // matrices and take advantage of the following identity (see // https://math.stackexchange.com/a/30077) // A = UΣVᵀ // Aᵀ = VΣUᵀ // As a result some of the indices/sizes are swapped as noted above. // Rows and cols of the original matrix in row-major order. const int M = a.shape(-2); const int N = a.shape(-1); const int K = std::min(M, N); using R = typename SVDWork::R; size_t num_matrices = a.size() / (M * N); // lapack clobbers the input, so we have to make a copy. array in(a.shape(), a.dtype(), nullptr, {}); copy_cpu( a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General, stream); // Allocate outputs. auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(a); auto in_ptr = in.data(); T* u_ptr; R* s_ptr; T* vt_ptr; if (compute_uv) { array& u = outputs[0]; array& s = outputs[1]; array& vt = outputs[2]; u.set_data(allocator::malloc(u.nbytes())); s.set_data(allocator::malloc(s.nbytes())); vt.set_data(allocator::malloc(vt.nbytes())); encoder.set_output_array(u); encoder.set_output_array(s); encoder.set_output_array(vt); s_ptr = s.data(); u_ptr = u.data(); vt_ptr = vt.data(); } else { array& s = outputs[0]; s.set_data(allocator::malloc(s.nbytes())); encoder.set_output_array(s); s_ptr = s.data(); u_ptr = nullptr; vt_ptr = nullptr; } encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() { auto jobz = (u_ptr) ? 'A' : 'N'; SVDWork svd_work(N, M, K, jobz); // Loop over matrices. for (size_t i = 0; i < num_matrices; i++) { svd_work.run( in_ptr + M * N * i, s_ptr + K * i, vt_ptr ? vt_ptr + N * N * i : nullptr, u_ptr ?
u_ptr + M * M * i : nullptr); } }); encoder.add_temporary(in); } void SVD::eval_cpu( const std::vector& inputs, std::vector& outputs) { switch (inputs[0].dtype()) { case float32: svd_impl(inputs[0], outputs, compute_uv_, stream()); break; case float64: svd_impl(inputs[0], outputs, compute_uv_, stream()); break; case complex64: svd_impl>(inputs[0], outputs, compute_uv_, stream()); break; default: throw std::runtime_error( "[SVD::eval_cpu] only supports float32, float64, or complex64."); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/ternary.h ================================================ // Copyright © 2023 Apple Inc. #pragma once #include "mlx/array.h" #include "mlx/backend/common/ternary.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/encoder.h" namespace mlx::core { template void ternary_op_dims( const T1* a, const T2* b, const T3* c, U* out, Op op, const Shape& shape, const Strides& a_strides, const Strides& b_strides, const Strides& c_strides, const Strides& out_strides, int axis) { auto stride_a = a_strides[axis]; auto stride_b = b_strides[axis]; auto stride_c = c_strides[axis]; auto stride_out = out_strides[axis]; auto N = shape[axis]; for (int i = 0; i < N; i++) { if constexpr (D > 1) { ternary_op_dims( a, b, c, out, op, shape, a_strides, b_strides, c_strides, out_strides, axis + 1); } else { *out = op(*a, *b, *c); } a += stride_a; b += stride_b; c += stride_c; out += stride_out; } } template void ternary_op_dispatch_dims( const T1* a_ptr, const T2* b_ptr, const T3* c_ptr, U* out_ptr, Op op, size_t size, Shape& shape, std::vector& strides) { const auto& a_strides = strides[0]; const auto& b_strides = strides[1]; const auto& c_strides = strides[2]; const auto& out_strides = strides[3]; int ndim = shape.size(); switch (ndim) { case 1: ternary_op_dims( a_ptr, b_ptr, c_ptr, out_ptr, op, shape, a_strides, b_strides, c_strides, out_strides, 0); return; case 2: ternary_op_dims( a_ptr, b_ptr,
c_ptr, out_ptr, op, shape, a_strides, b_strides, c_strides, out_strides, 0); return; } ContiguousIterator a_it(shape, a_strides, ndim - 2); ContiguousIterator b_it(shape, b_strides, ndim - 2); ContiguousIterator c_it(shape, c_strides, ndim - 2); auto stride = out_strides[ndim - 3]; for (size_t elem = 0; elem < size; elem += stride) { ternary_op_dims( a_ptr + a_it.loc, b_ptr + b_it.loc, c_ptr + c_it.loc, out_ptr + elem, op, shape, a_strides, b_strides, c_strides, out_strides, ndim - 2); a_it.step(); b_it.step(); c_it.step(); } } template void ternary_op( const array& a, const array& b, const array& c, array& out, Op op, TernaryOpType topt) { const T1* a_ptr = a.data(); const T2* b_ptr = b.data(); const T3* c_ptr = c.data(); U* out_ptr = out.data(); if (topt == TernaryOpType::ScalarScalarScalar) { *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); } else if (topt == TernaryOpType::VectorVectorVector) { for (size_t i = 0; i < out.size(); ++i) { *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); a_ptr++; b_ptr++; c_ptr++; out_ptr++; } } else { auto [shape, strides] = collapse_contiguous_dims( a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()}); ternary_op_dispatch_dims( a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cpu/threefry.cpp ================================================ // Copyright © 2023 Apple Inc.
#include "mlx/backend/cpu/threefry.h"

namespace mlx::core::random {

// Threefry-2x32 block function with 20 rounds. The rounds come in 5 groups
// of 4 add / rotate-left / xor rounds. The groups alternate between the two
// rows of `rotations`, and a key injection follows every group. `count` is
// taken by value, mixed in place and returned.
// NOTE(review): the std::pair template arguments appear to have been stripped
// from this extracted copy; restore them from upstream.
std::pair threefry2x32_hash(
    const std::pair& key,
    std::pair count) {
  constexpr static uint32_t rotations[2][4] = {
      {13, 15, 26, 6}, {17, 29, 16, 24}};
  // Key schedule: the two key words, plus their xor with the Threefry parity
  // constant 0x1BD11BDA.
  uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};

  // Initial key injection.
  count.first += ks[0];
  count.second += ks[1];

  for (int i = 0; i < 5; ++i) {
    for (auto r : rotations[i % 2]) {
      count.first += count.second;
      // Rotate left by r. Every r lies in [6, 29], so neither shift is by 0
      // or 32.
      count.second = (count.second << r) | (count.second >> (32 - r));
      count.second ^= count.first;
    }
    // Key injection i + 1: cycle through the key schedule, and add the
    // injection index to the second word.
    count.first += ks[(i + 1) % 3];
    count.second += ks[(i + 2) % 3] + i + 1;
  }
  return count;
}

} // namespace mlx::core::random
================================================
FILE: mlx/backend/cpu/threefry.h
================================================
// Copyright © 2023 Apple Inc.

#pragma once

#include
#include

namespace mlx::core::random {

/** Applies the Threefry 2x32 hash function.
 * This code is based on the Jax counter-based and splittable PRNG
 * https://github.com/google/jax/blob/main/docs/jep/263-prng.md
 *
 * Original Threefry reference:
 * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
 *
 * @param key   the two 32-bit key words.
 * @param count the two 32-bit counter words to hash (taken by value).
 * @return the hashed counter pair after 20 Threefry-2x32 rounds.
 */
std::pair threefry2x32_hash(
    const std::pair& key,
    std::pair count);

} // namespace mlx::core::random
================================================
FILE: mlx/backend/cpu/unary.cpp
================================================
// Copyright © 2024 Apple Inc.

// Required for using M_LN2 in MSVC.
#define _USE_MATH_DEFINES

#include
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/primitives.h"

namespace mlx::core {

// CPU evaluation of the unary primitives. Each eval_cpu forwards its single
// input to a dtype-restricted dispatcher from unary.h, together with the
// matching functor from unary_ops.h. The dispatchers are unary, unary_fp,
// unary_real_fp, unary_signed, unary_int, unary_complex and
// unary_complex_to_float. When the result equals the input for the input's
// dtype, some ops skip the kernel and call copy_shared_buffer instead.
// NOTE(review): template arguments appear to have been stripped from this
// extracted copy (e.g. `std::vector& inputs`); restore them from upstream.

void Abs::eval_cpu(const std::vector& inputs, array& out) {
  auto& in = inputs[0];
  if (issubdtype(in.dtype(), unsignedinteger) || in.dtype() == bool_) {
    // No-op for unsigned and bool types (|x| == x).
    out.copy_shared_buffer(in);
  } else {
    unary_signed(in, out, detail::Abs(), stream());
  }
}

void ArcCos::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::ArcCos(), stream());
}

void ArcCosh::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::ArcCosh(), stream());
}

void ArcSin::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::ArcSin(), stream());
}

void ArcSinh::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::ArcSinh(), stream());
}

void ArcTan::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::ArcTan(), stream());
}

void ArcTanh::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::ArcTanh(), stream());
}

void BitwiseInvert::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_int(in, out, detail::BitwiseInvert(), stream());
}

void Ceil::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (issubdtype(in.dtype(), inexact)) {
    unary_fp(in, out, detail::Ceil(), stream());
  } else {
    // No-op for integer (non-inexact) types: ceil leaves them unchanged.
    out.copy_shared_buffer(in);
  }
}

void Conjugate::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  unary_complex(inputs[0], out, detail::Conjugate(), stream());
}

void Cos::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::Cos(), stream());
}

void Cosh::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::Cosh(), stream());
}

// Erf / ErfInv accept real floating-point types only (no complex64).
void Erf::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_real_fp(in, out, detail::Erf(), stream());
}

void ErfInv::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_real_fp(in, out, detail::ErfInv(), stream());
}

void Exp::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::Exp(), stream());
}

void Expm1::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::Expm1(), stream());
}

void Floor::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (issubdtype(in.dtype(), inexact)) {
    unary_fp(in, out, detail::Floor(), stream());
  } else {
    // No-op for integer (non-inexact) types: floor leaves them unchanged.
    out.copy_shared_buffer(in);
  }
}

void Imag::eval_cpu(const std::vector& inputs, array& out) {
  unary_complex_to_float(inputs[0], out, detail::Imag(), stream());
}

// base_ selects the natural, base-2 or base-10 logarithm.
void Log::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  switch (base_) {
    case Base::e:
      unary_fp(in, out, detail::Log(), stream());
      break;
    case Base::two:
      unary_fp(in, out, detail::Log2(), stream());
      break;
    case Base::ten:
      unary_fp(in, out, detail::Log10(), stream());
      break;
  }
}

void Log1p::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::Log1p(), stream());
}

void LogicalNot::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  unary(in, out, detail::LogicalNot(), stream());
}

void Negative::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  unary(in, out, detail::Negative(), stream());
}

void Real::eval_cpu(const std::vector& inputs, array& out) {
  unary_complex_to_float(inputs[0], out, detail::Real(), stream());
}

void Round::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (issubdtype(in.dtype(), inexact)) {
    unary_fp(in, out, detail::Round(), stream());
  } else {
    // No-op for integer (non-inexact) types: rounding leaves them unchanged.
    out.copy_shared_buffer(in);
  }
}

void Sigmoid::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::Sigmoid(), stream());
}

void Sign::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (in.dtype() == bool_) {
    // sign(false) == false and sign(true) == true, so reuse the input.
    out.copy_shared_buffer(in);
  } else {
    unary(in, out, detail::Sign(), stream());
  }
}

void Sin::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::Sin(), stream());
}

void Sinh::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::Sinh(), stream());
}

void Square::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  unary(in, out, detail::Square(), stream());
}

// recip_ selects 1 / sqrt(x) (Rsqrt) instead of sqrt(x).
void Sqrt::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (recip_) {
    unary_fp(in, out, detail::Rsqrt(), stream());
  } else {
    unary_fp(in, out, detail::Sqrt(), stream());
  }
}

void Tan::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::Tan(), stream());
}

void Tanh::eval_cpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  unary_fp(in, out, detail::Tanh(), stream());
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cpu/unary.h
================================================
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/utils.h"

namespace mlx::core {

// NOTE(review): template parameter lists and explicit template arguments
// appear to have been stripped from this extracted copy (e.g. `template void`,
// `a.data()`, and the repeated `std::min(simd::max_size, simd::max_size)`);
// restore them from upstream.

// Strided 1-D kernel. Reads `shape` elements from `a`, spaced `stride`
// apart, and writes Op of each to out[0 .. shape) contiguously.
template
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
  for (size_t i = 0; i < shape; i += 1) {
    out[i] = Op{}(*a);
    a += stride;
  }
}

// Applies Op to every element of `a` and writes the results into `out`.
// Contiguous inputs are processed over a.data_size() elements: SIMD chunks
// of N first, then a scalar tail. For other inputs, the last axis is handled
// by the strided kernel above and the leading axes are walked with a
// ContiguousIterator. In that case `out` is written densely (dst + elem).
template
void unary_op(const array& a, array& out, Op) {
  const T* src = a.data();
  U* dst = out.data();
  auto ndim = a.ndim();
  if (a.flags().contiguous) {
    auto size = a.data_size();
    // SIMD width; presumably the smaller of the input and output widths.
    constexpr int N = std::min(simd::max_size, simd::max_size);
    while (size >= N) {
      simd::store(dst, simd::Simd(Op{}(simd::load(src))));
      size -= N;
      src += N;
      dst += N;
    }
    // Scalar tail for the remaining (original size % N) elements.
    while (size > 0) {
      *dst = Op{}(*src);
      size--;
      dst++;
      src++;
    }
  } else {
    size_t shape = ndim > 0 ? a.shape().back() : 1;
    size_t stride = ndim > 0 ? a.strides().back() : 1;
    if (ndim <= 1) {
      unary_op(src, dst, shape, stride);
      return;
    }
    auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
    for (size_t elem = 0; elem < a.size(); elem += shape) {
      unary_op(src + it.loc, dst + elem, shape, stride);
      it.step();
    }
  }
}

// Generic unary dispatcher for the dtypes listed in the switch. It sets up
// the output storage with set_unary_output_data and registers a/out with the
// stream's CPU command encoder. It then dispatches a task that selects the
// kernel from out.dtype(). The task captures array::unsafe_weak_copy of
// a/out.
// NOTE(review): the switch has no default case, so any dtype not listed here
// would be silently left unwritten; the other dispatchers throw instead.
template
void unary(const array& a, array& out, Op op, Stream stream) {
  set_unary_output_data(a, out);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    out = array::unsafe_weak_copy(out),
                    op = op]() mutable {
    switch (out.dtype()) {
      case bool_:
        unary_op(a, out, op);
        break;
      case uint8:
        unary_op(a, out, op);
        break;
      case uint16:
        unary_op(a, out, op);
        break;
      case uint32:
        unary_op(a, out, op);
        break;
      case uint64:
        unary_op(a, out, op);
        break;
      case int8:
        unary_op(a, out, op);
        break;
      case int16:
        unary_op(a, out, op);
        break;
      case int32:
        unary_op(a, out, op);
        break;
      case int64:
        unary_op(a, out, op);
        break;
      case float16:
        unary_op(a, out, op);
        break;
      case float32:
        unary_op(a, out, op);
        break;
      case float64:
        unary_op(a, out, op);
        break;
      case bfloat16:
        unary_op(a, out, op);
        break;
      case complex64:
        unary_op(a, out, op);
        break;
    }
  });
}

// Dispatcher restricted to real floating-point outputs (bfloat16, float16,
// float32, float64). Any other dtype throws std::runtime_error, with a
// message tagged "[unary_real]".
template
void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
  set_unary_output_data(a, out);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    out = array::unsafe_weak_copy(out),
                    op = op]() mutable {
    switch (out.dtype()) {
      case bfloat16:
        unary_op(a, out, op);
        break;
      case float16:
        unary_op(a, out, op);
        break;
      case float32:
        unary_op(a, out, op);
        break;
      case float64:
        unary_op(a, out, op);
        break;
      default:
        std::ostringstream err;
        err << "[unary_real] Does not support " << out.dtype();
        throw std::runtime_error(err.str());
    }
  });
}

// Dispatcher for floating-point outputs including complex64. Any other dtype
// throws std::runtime_error.
template
void unary_fp(const array& a, array& out, Op op, Stream stream) {
  set_unary_output_data(a, out);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    out = array::unsafe_weak_copy(out),
                    op = op]() mutable {
    switch (out.dtype()) {
      case bfloat16:
        unary_op(a, out, op);
        break;
      case float16:
        unary_op(a, out, op);
        break;
      case float32:
        unary_op(a, out, op);
        break;
      case float64:
        unary_op(a, out, op);
        break;
      case complex64:
        unary_op(a, out, op);
        break;
      default:
        std::ostringstream err;
        err << "[unary_fp] Does not support " << out.dtype();
        throw std::runtime_error(err.str());
    }
  });
}

// Dispatcher for signed integer, floating-point and complex64 outputs.
// NOTE(review): the error message names "[Abs]". In this file unary_signed is
// only called from Abs::eval_cpu, but the helper itself is generic.
template
void unary_signed(const array& a, array& out, Op op, Stream stream) {
  set_unary_output_data(a, out);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    out = array::unsafe_weak_copy(out),
                    op = op]() mutable {
    switch (out.dtype()) {
      case int8:
        unary_op(a, out, op);
        break;
      case int16:
        unary_op(a, out, op);
        break;
      case int32:
        unary_op(a, out, op);
        break;
      case int64:
        unary_op(a, out, op);
        break;
      case float16:
        unary_op(a, out, op);
        break;
      case float32:
        unary_op(a, out, op);
        break;
      case float64:
        unary_op(a, out, op);
        break;
      case bfloat16:
        unary_op(a, out, op);
        break;
      case complex64:
        unary_op(a, out, op);
        break;
      default:
        throw std::runtime_error("[Abs] Called on unsigned type");
    }
  });
}

// Complex-only dispatcher: no dtype switch. The kernel instantiation is fixed
// (its template arguments are stripped in this copy; presumably complex64 in
// and out).
template
void unary_complex(const array& a, array& out, Op op, Stream stream) {
  set_unary_output_data(a, out);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    out = array::unsafe_weak_copy(out),
                    op = op]() mutable { unary_op(a, out, op); });
}

// Complex-to-real dispatcher (used by Real / Imag): no dtype switch. The
// kernel instantiation is fixed; presumably complex64 in, float32 out.
template
void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
  set_unary_output_data(a, out);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(out);
  encoder.dispatch(
      [a = array::unsafe_weak_copy(a),
       out = array::unsafe_weak_copy(out),
       op = op]() mutable { unary_op(a, out, op); });
}

// Dispatcher for integer outputs (signed and unsigned; bool is not
// included). Any other dtype throws std::runtime_error.
template
void unary_int(const array& a, array& out, Op op, Stream stream) {
  set_unary_output_data(a, out);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    out = array::unsafe_weak_copy(out),
                    op = op]() mutable {
    switch (out.dtype()) {
      case uint8:
        unary_op(a, out, op);
        break;
      case uint16:
        unary_op(a, out, op);
        break;
      case uint32:
        unary_op(a, out, op);
        break;
      case uint64:
        unary_op(a, out, op);
        break;
      case int8:
        unary_op(a, out, op);
        break;
      case int16:
        unary_op(a, out, op);
        break;
      case int32:
        unary_op(a, out, op);
        break;
      case int64:
        unary_op(a, out, op);
        break;
      default:
        std::ostringstream err;
        err << "[unary_int] Does not support " << out.dtype();
        throw std::runtime_error(err.str());
    }
  });
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cpu/unary_ops.h
================================================
// Copyright © 2024 Apple Inc.
#pragma once #include #include #include #include "mlx/backend/cpu/simd/simd.h" namespace mlx::core::detail { using namespace mlx::core::simd; #define SINGLE() \ template \ T operator()(T x) { \ return (*this)(Simd(x)).value; \ } #define DEFAULT_OP(Op, op) \ struct Op { \ template \ Simd operator()(Simd x) { \ return simd::op(x); \ } \ SINGLE() \ }; DEFAULT_OP(Abs, abs) DEFAULT_OP(ArcCos, acos) DEFAULT_OP(ArcCosh, acosh) DEFAULT_OP(ArcSin, asin) DEFAULT_OP(ArcSinh, asinh) DEFAULT_OP(ArcTan, atan) DEFAULT_OP(ArcTanh, atanh) DEFAULT_OP(BitwiseInvert, operator~) DEFAULT_OP(Ceil, ceil) DEFAULT_OP(Conjugate, conj) DEFAULT_OP(Cos, cos) DEFAULT_OP(Cosh, cosh) DEFAULT_OP(Erf, erf) DEFAULT_OP(ErfInv, erfinv) DEFAULT_OP(Exp, exp) DEFAULT_OP(Expm1, expm1) DEFAULT_OP(Floor, floor); DEFAULT_OP(Log, log); DEFAULT_OP(Log2, log2); DEFAULT_OP(Log10, log10); DEFAULT_OP(Log1p, log1p); DEFAULT_OP(LogicalNot, operator!) DEFAULT_OP(Negative, operator-) DEFAULT_OP(Round, rint); DEFAULT_OP(Sin, sin) DEFAULT_OP(Sinh, sinh) DEFAULT_OP(Sqrt, sqrt) DEFAULT_OP(Rsqrt, rsqrt) DEFAULT_OP(Tan, tan) DEFAULT_OP(Tanh, tanh) struct Imag { template Simd operator()(Simd x) { return simd::imag(x); } SINGLE() }; struct Real { template Simd operator()(Simd x) { return simd::real(x); } SINGLE() }; struct Sigmoid { template Simd operator()(Simd x) { auto y = 1.0f / (1.0f + simd::exp(simd::abs(x))); return simd::select(x < Simd{0}, y, Simd{1} - y); } SINGLE() }; struct Sign { template Simd operator()(Simd x) { auto z = Simd{0}; auto o = Simd{1}; auto m = Simd{-1}; if constexpr (std::is_unsigned_v) { return simd::select(x == z, z, o); } else if constexpr (std::is_same_v) { return simd::select(x == z, x, Simd(x / simd::abs(x))); } else { return simd::select(x < z, m, simd::select(x > z, o, z)); } } SINGLE() }; struct Square { template Simd operator()(Simd x) { return x * x; } SINGLE() }; template Simd fp32_from_bits(Simd x) { return *(Simd*)(&x); } template Simd fp32_to_bits(Simd x) { return *(Simd*)(&x); } 
// float32 -> 8-bit float conversion. The constants imply an E4M3 layout:
// exponent re-biased from 127 to 7, 3 mantissa bits kept by `>> 20`, and
// saturation to code 0x7E. Rounding is to nearest even. The returned code is
// the sign bit OR-ed onto the magnitude code.
struct ToFP8 {
  template
  Simd operator()(Simd f) {
    // 543 << 21 == 0x43E00000, the float32 bit pattern of 448.0f. Magnitudes
    // at or above it saturate to 0x7E below.
    uint32_t fp8_max = 543 << 21;
    // 141 << 23 is the float32 bit pattern of 2^14. Adding it as a float
    // rounds a small magnitude to a multiple of 2^-9, the subnormal step.
    auto denorm_mask = Simd(141 << 23);
    Simd f_bits;
    Simd f32 = f;
    f_bits = fp32_to_bits(f32);
    Simd result = 0u;
    // Split off the sign bit; the rest operates on the magnitude.
    auto sign = f_bits & 0x80000000;
    f_bits = f_bits ^ sign;

    // Subnormal path: let the FPU round through a float addition, then
    // subtract the bias pattern. This leaves the 8-bit code in the low bits.
    auto f_bits_low =
        fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
    auto result_low = Simd(f_bits_low - denorm_mask);

    // Normal path: re-bias the exponent from 127 to 7, then round to nearest
    // even on the 20 dropped mantissa bits (0x7FFFF plus the lowest kept bit).
    auto mant_odd = Simd((f_bits >> 20) & 1);
    auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
    f_bits_high = f_bits_high + Simd(mant_odd);
    auto result_high = Simd(f_bits_high >> 20);

    // 121 << 23 is the bit pattern of 2^-6. Smaller magnitudes take the
    // subnormal path.
    result = select(f_bits < (121 << 23), result_low, result_high);
    // NOTE(review): Inf and NaN bit patterns also compare >= fp8_max, so they
    // saturate to +/-0x7E (448) instead of producing NaN -- confirm intended.
    auto result_sat = Simd(0x7E);
    result = select(f_bits >= fp8_max, result_sat, result);
    return result | Simd(sign >> 24);
  }
  // Scalar overload: converts one value through the SIMD path.
  template
  uint8_t operator()(T x) {
    return (*this)(Simd(x)).value;
  }
};

// 8-bit float -> float conversion, the inverse layout of ToFP8. The 7
// magnitude bits are shifted left by 7 into float16 position: the 4 exponent
// bits land in the low bits of the 5-bit half exponent, and the 3 mantissa
// bits land at the top of the 10-bit mantissa. The result is reinterpreted as
// float16 and multiplied by 256 = 2^8 to correct the exponent-bias difference
// (15 - 7). Bit 7 (x & 128) carries the sign.
struct FromFP8 {
  template
  Simd operator()(Simd x) {
    auto v = Simd(x & 127) << 7;
    Simd out;
    if constexpr (simd::max_size >= N) {
      // Whole-vector reinterpretation when a wide enough float16 Simd exists
      // (presumably what the stripped max_size argument checks).
      auto converted = *(Simd*)(&v);
      out = converted * 256.0;
    } else {
      // Element-by-element fallback.
      for (int i = 0; i < N; ++i) {
        auto converted = *(float16_t*)(&v[i]);
        out[i] = converted * 256.0;
      }
    }
    auto sign = Simd(x & 128);
    return select(sign, -out, out);
  }
  // Scalar overload: converts one code through the SIMD path.
  float operator()(uint8_t x) {
    return (*this)(Simd(x)).value;
  }
};

} // namespace mlx::core::detail
================================================
FILE: mlx/backend/cuda/CMakeLists.txt
================================================
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only code should be put in device/ subdir.
# * Files in device/ subdir should not include files outside.
# CUDA backend sources. They are split into several target_sources() calls,
# each of which appends to the `mlx` SOURCES property. The overall order
# matches the single list this replaces.

# Core entry points, copies and convolutions.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu)

# Library helpers, device/runtime plumbing and GEMMs.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cublas_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu)

# Remaining primitives and reductions.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)

# Quantization kernels, then worker.cpp.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)

# The qqmm implementation needs CUDA >= 12.8 (fp4 support). Older toolkits
# get no_qqmm_impl.cpp plus the quantized/ include directory instead.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_impl.cpp
                ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp)
else()
  # fp4 is not available on < 12.8
  target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
  target_sources(mlx
                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/no_qqmm_impl.cpp)
endif()

# Batched cuBLAS GEMM: use the 12.0 variant before CUDA 12.9, and the 12.9
# variant from then on.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.9.0)
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
else()
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
endif()

# Embed kernel sources in binary for JIT compilation.
# Collect the device headers to embed, as paths relative to this directory.
# NOTE(review): without CONFIGURE_DEPENDS, newly added device headers are only
# picked up after CMake is re-run.
file(
  GLOB MLX_JIT_SOURCES
  RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
# Join with ":" instead of CMake's ";" list separator, so the whole list is
# passed to bin2h.cmake as a single -D value.
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
# Generate gen/cuda_jit_sources.h by running bin2h.cmake in script mode (-P).
# It is regenerated whenever bin2h.cmake or any of the globbed headers
# changes.
add_custom_command(
  OUTPUT gen/cuda_jit_sources.h
  COMMAND
    ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
    -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
    "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
  DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
# Make mlx depend on the generated header, and put its directory on the
# include path.
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
add_dependencies(mlx cuda_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")

# ------------------------ Compilation configs ------------------------

target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

# NOTE(review): the generator expressions below appear truncated in this
# extracted copy (`$<$:...>` is missing its inner condition, presumably a
# COMPILE_LANGUAGE check); restore them from upstream.

# Enable defining device lambda functions.
target_compile_options(mlx PRIVATE "$<$:--extended-lambda>")

# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
  mlx PRIVATE "$<$:--expt-relaxed-constexpr>")

if(MSVC)
  # Ignore warnings from CUTLASS.
  target_compile_options(
    mlx PRIVATE $<$:-Xcudafe="--diag_suppress=2908">)
else()
  # Required for generating optimized CUTLASS code.
  target_compile_options(
    mlx PRIVATE "$<$:-Xcompiler=-fno-strict-aliasing>")
endif()

# Suppress nvcc warnings on C++ headers.
target_compile_options(
  mlx PRIVATE $<$:-Xcudafe="--diag_suppress=27,997,1394,20011,20208">
)

# Ignore some valid nvcc warnings; we might want to fix them in the future.
target_compile_options(
  mlx PRIVATE $<$:-Xcudafe="--diag_suppress=177,550">)

# Use stronger compression for the binaries. This feature was introduced in
# CUDA 12.8 and requires drivers released after CUDA 12.4.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
  target_compile_options(
    mlx PRIVATE "$<$:--compress-mode=size>")
endif()

# Use native CUDA arch by default.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
  # Query the native arch with __nvcc_device_query. An empty result is fatal.
  execute_process(
    COMMAND __nvcc_device_query
    OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
    OUTPUT_STRIP_TRAILING_WHITESPACE)
  if(MLX_CUDA_ARCHITECTURES STREQUAL "")
    message(
      FATAL_ERROR
        "Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
  elseif(MLX_CUDA_ARCHITECTURES GREATER_EQUAL 90)
    # Use arch-specific compute capability whenever possible.
    # NOTE(review): GREATER_EQUAL is only true for a single numeric value. If
    # the query ever returns a ";"-separated list (several GPUs), no "a"
    # suffix is added -- confirm that is intended.
    set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
  endif()
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
                                     "${MLX_CUDA_ARCHITECTURES}")

# Enable the Hopper-only (sm90a) kernels only when 90a or 90a-real is among
# the target architectures and the MLX_DISABLE_SM90A_KERNELS environment
# variable is not set.
if(NOT DEFINED ENV{MLX_DISABLE_SM90A_KERNELS}
   AND (("90a" IN_LIST MLX_CUDA_ARCHITECTURES)
        OR ("90a-real" IN_LIST MLX_CUDA_ARCHITECTURES)))
  target_compile_definitions(mlx PRIVATE MLX_CUDA_SM90A_ENABLED)
endif()

# Locate CUDA libraries at runtime, optionally from installed Python packages
# (MLX_LOAD_CUDA_LIBS_FROM_PYTHON).
if(WIN32)
  # Delay-load DLLs so that their paths can be resolved at runtime.
  if(BUILD_SHARED_LIBS)
    target_link_libraries(mlx PRIVATE "delayimp.lib")
    target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/delayload.cpp)
  else()
    # For static library the delayload must be compiled into final executables.
    target_link_libraries(mlx PUBLIC "delayimp.lib")
    # NOTE(review): the generator expression here appears truncated in this
    # extracted copy (`$)`); restore it from upstream.
    target_sources(
      mlx PUBLIC $)
  endif()
  # Get all the CUDA DLLs we could link with.
  file(
    GLOB CUDA_DLL_NAMES
    RELATIVE "${CUDAToolkit_BIN_DIR}/x64"
    "${CUDAToolkit_BIN_DIR}/x64/*.dll")
  # Delay load CUDA and cuDNN libs. CUDNN_DLL_NAMES (and CUDNN_BIN_DIR below)
  # are not set in this file; presumably cmake/FindCUDNN.cmake sets them.
  foreach(CUDA_DLL ${CUDA_DLL_NAMES} ${CUDNN_DLL_NAMES})
    target_link_options(mlx PUBLIC "/DELAYLOAD:${CUDA_DLL}")
  endforeach()
  # Pass the locations where CUDA DLLs are placed.
  if(NOT MLX_LOAD_CUDA_LIBS_FROM_PYTHON)
    target_compile_definitions(
      mlx PUBLIC MLX_CUDA_BIN_DIR="${CUDAToolkit_BIN_DIR}/x64"
                 MLX_CUDNN_BIN_DIR="${CUDNN_BIN_DIR}")
  endif()
else()
  # For POSIX we rely on RPATH to search for CUDA libs.
  if(MLX_LOAD_CUDA_LIBS_FROM_PYTHON)
    set_property(
      TARGET mlx
      APPEND
      PROPERTY INSTALL_RPATH
               # The paths here should match the install_requires in setup.py.
               "$ORIGIN/../../nvidia/cublas/lib"
               "$ORIGIN/../../nvidia/cuda_nvrtc/lib"
               "$ORIGIN/../../nvidia/cudnn/lib"
               "$ORIGIN/../../nvidia/nccl/lib")
  endif()
endif()

# ------------------------ Dependencies ------------------------

# Use fixed version of CCCL.
FetchContent_Declare(
  cccl
  URL "https://github.com/NVIDIA/cccl/releases/download/v3.1.3/cccl-v3.1.3.zip")
FetchContent_MakeAvailable(cccl)
# BEFORE places the fetched CCCL headers ahead of other include directories,
# presumably so they win over the copy bundled with the CUDA toolkit.
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")

# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
install(DIRECTORY ${cccl_SOURCE_DIR}/include/nv
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

# The C++ test binary is not installed, so it cannot find the installed CCCL
# headers; hard-code the source path for it instead.
if(MLX_BUILD_TESTS)
  target_compile_definitions(mlx
                             PRIVATE MLX_CCCL_DIR="${cccl_SOURCE_DIR}/include")
endif()

# Use fixed version of NVTX.
FetchContent_Declare(
  nvtx3
  GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
  GIT_TAG v3.1.1
  GIT_SHALLOW TRUE
  SOURCE_SUBDIR c
  EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nvtx3)
# NOTE(review): generator expression truncated in this extracted copy (`$`).
target_link_libraries(mlx PUBLIC $)

# Make cuda runtime APIs available in non-cuda files.
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})

# Use cublasLt.
target_link_libraries(mlx PRIVATE CUDA::cublasLt)

# Use cuFFT.
target_link_libraries(mlx PRIVATE CUDA::cufft)

# Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)

# Use the frontend APIs of cuDNN.
FetchContent_Declare(
  cudnn
  GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
  GIT_TAG v1.16.0
  GIT_SHALLOW TRUE
  EXCLUDE_FROM_ALL)
# Skip cudnn-frontend's JSON library, samples, tests and Python bindings.
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
FetchContent_MakeAvailable(cudnn)
target_link_libraries(mlx PRIVATE cudnn_frontend)
# Link with the actual cuDNN libraries.
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)

# Use header-only CUTLASS. SOURCE_SUBDIR points at include/, presumably so
# that CUTLASS's own top-level CMake project is not configured.
FetchContent_Declare(
  cutlass
  GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
  GIT_TAG v4.3.5
  GIT_SHALLOW TRUE
  SOURCE_SUBDIR include
  EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(cutlass)
# NOTE(review): generator expression truncated in this extracted copy (`$`).
target_include_directories(
  mlx SYSTEM PRIVATE $)
================================================
FILE: mlx/backend/cuda/allocator.cpp
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/gpu/device_info.h"
#include "mlx/memory.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"

#include
#include
#include
#include
#include

namespace mlx::core {

namespace cu {

// Granularity in bytes used by CudaAllocator::malloc_async to round up
// requests of at least this size; also passed to the buffer cache.
constexpr int page_size = 16384;
// Allocations of at most this many bytes will try to use the small pool.
constexpr int small_block_size = 8;
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;

// Check if running on Windows or Windows Subsystem for Linux
bool is_windows() {
#if defined(_WIN32)
  return true;
#elif defined(__linux__)
  // WSL kernels contain "microsoft" or "WSL" in /proc/version. Only the first
  // line is read, and the result is cached in a function-local static.
  static bool is_wsl = []() {
    std::ifstream version("/proc/version");
    if (version.is_open()) {
      std::string line;
      std::getline(version, line);
      return line.find("microsoft") != std::string::npos ||
          line.find("Microsoft") != std::string::npos ||
          line.find("WSL") != std::string::npos;
    }
    return false;
  }();
  return is_wsl;
#else
  return false;
#endif
}

// True when every device reports managed-memory support. On Windows/WSL,
// every device must also report concurrent managed access. The result is
// computed once and cached.
bool supports_managed_memory() {
  static bool managed_memory = []() {
    int device_count = gpu::device_count();
    for (int i = 0; i < device_count; ++i) {
      auto& d = cu::device(i);
      if (!d.managed_memory()) {
        return false;
      }
      // Empirically on Windows (and WSL) if there is no concurrentManagedAccess
      // the managed memory also does not work.
      if (is_windows() && !d.concurrent_managed_access()) {
        return false;
      }
    }
    return true;
  }();
  return managed_memory;
}

// Allocates host-accessible memory: managed memory when it is supported,
// otherwise page-locked host memory (cudaMallocHost).
inline void* unified_malloc(size_t size) {
  void* data = nullptr;
  if (supports_managed_memory()) {
    CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
  } else {
    CHECK_CUDA_ERROR(cudaMallocHost(&data, size));
  }
  return data;
}

// Releases memory obtained from unified_malloc with the matching free call.
inline void unified_free(void* data) {
  if (supports_managed_memory()) {
    CHECK_CUDA_ERROR(cudaFree(data));
  } else {
    CHECK_CUDA_ERROR(cudaFreeHost(data));
  }
}

// Location argument for cudaMemAdvise: a cudaMemLocation when
// CUDART_VERSION >= 13000, a plain device ordinal otherwise.
#if CUDART_VERSION >= 13000
inline cudaMemLocation cuda_mem_loc(int i) {
  cudaMemLocation loc;
  loc.type = cudaMemLocationTypeDevice;
  loc.id = i;
  return loc;
}
#else
inline int cuda_mem_loc(int i) {
  return i;
}
#endif // CUDART_VERSION >= 13000

// Pool of small_block_size-byte blocks carved out of one small_pool_size-byte
// unified allocation (data_). The Block headers live in buffer_ and are
// linked into a singly linked free list headed by next_free_.
SmallSizePool::SmallSizePool() {
  auto num_blocks = small_pool_size / small_block_size;
  buffer_ = new Block[num_blocks];

  next_free_ = buffer_;

  data_ = unified_malloc(small_pool_size);
  if (supports_managed_memory()) {
    int device_count = gpu::device_count();
    for (int i = 0; i < device_count; ++i) {
      if (device(i).concurrent_managed_access()) {
        auto loc = cuda_mem_loc(i);
        CHECK_CUDA_ERROR(cudaMemAdvise(
            data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
      }
    }
  }

  // Link every block to its successor; the last block terminates the list.
  auto curr = next_free_;
  for (size_t i = 1; i < num_blocks; ++i) {
    curr->next = buffer_ + i;
    curr = curr->next;
  }
  curr->next = nullptr;
}

SmallSizePool::~SmallSizePool() {
  unified_free(data_);
  delete[] buffer_;
}

// Pops a block from the free list and points its CudaBuffer at the matching
// small_block_size-byte slice of data_. Returns nullptr when the pool is
// exhausted. Pool buffers report device -1.
CudaBuffer* SmallSizePool::malloc() {
  if (next_free_ == nullptr) {
    return nullptr;
  }
  Block* b = next_free_;
  uint64_t i = next_free_ - buffer_;
  // Advance the free list before writing b->buf. If Block overlays `next`
  // and `buf` (Block is declared elsewhere), writing buf would clobber next.
  next_free_ = next_free_->next;
  b->buf.data = static_cast(data_) + i * small_block_size;
  b->buf.size = small_block_size;
  b->buf.device = -1;
  return &b->buf;
}

// Pushes the block that owns `buf` back onto the free list.
// NOTE(review): the cast from CudaBuffer* to the Block type relies on `buf`
// sitting at the start of Block (declared elsewhere) -- confirm.
void SmallSizePool::free(CudaBuffer* buf) {
  auto b = reinterpret_cast(buf);
  b->next = next_free_;
  next_free_ = b;
}

// True when `buf` points into buffer_, i.e. it was handed out by this pool.
// NOTE(review): when buf did not come from this pool, `b - buffer_` subtracts
// pointers into unrelated objects, which is technically undefined behaviour.
// Comparing addresses via std::less or uintptr_t would be strictly
// conforming.
bool SmallSizePool::in_pool(CudaBuffer* buf) {
  constexpr int num_blocks = (small_pool_size / small_block_size);
  auto b = reinterpret_cast(buf);
  int64_t block_num = b - buffer_;
  return block_num >= 0 && block_num < num_blocks;
}

// Sets up the buffer cache (page_size granularity, size and free callbacks)
// and derives the limits from cudaMemGetInfo. memory_limit_ and
// max_pool_size_ start at 95% of total memory; free_limit_ is the remaining
// 5%. For every device that supports memory pools, a free stream is created
// and the default pool handle is fetched.
// NOTE(review): cudaMemGetInfo reports on the current device only, so on
// multi-GPU systems the limits reflect that single device.
CudaAllocator::CudaAllocator()
    : buffer_cache_(
          page_size,
          [](CudaBuffer* buf) { return buf->size; },
          [this](CudaBuffer* buf) { free_cuda_buffer(buf); }) {
  size_t free;
  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
  memory_limit_ = total_memory_ * 0.95;
  free_limit_ = total_memory_ - memory_limit_;
  max_pool_size_ = memory_limit_;

  int device_count = gpu::device_count();
  free_streams_.resize(device_count);
  mem_pools_.resize(device_count);
  for (int i = 0; i < device_count; ++i) {
    auto& d = device(i);
    if (d.memory_pools()) {
      free_streams_[i] = CudaStream(d);
      CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pools_[i], i));
    }
  }
}

// Allocates `size` bytes for `device`, ordered on `stream`.
// - A zero-size request returns a buffer with a null pointer and touches no
//   accounting.
// - Sizes are rounded up: to 8 bytes for tiny requests, to the next power of
//   two below page_size, and to a multiple of page_size otherwise.
// - Tiny requests, and requests without a stream, use unified memory
//   (device -1).
// - A cached buffer is reused when possible. Otherwise the cache is trimmed
//   under memory pressure, tiny requests try the small pool, and fresh memory
//   is allocated with the mutex released.
// - Throws std::runtime_error if the new allocation comes back null.
Buffer
CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
  if (size == 0) {
    return Buffer{new CudaBuffer{nullptr, 0, -1}};
  }

  if (size <= small_block_size) {
    size = 8;
  } else if (size < page_size) {
    size = next_power_of_2(size);
  } else {
    size = page_size * ((size + page_size - 1) / page_size);
  }

  if (size <= small_block_size || stream == nullptr) {
    device = -1;
  }

  // Find available buffer from cache.
  std::unique_lock lock(mutex_);
  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
  if (!buf) {
    // If we have a lot of memory pressure try to reclaim memory from the cache.
    // NOTE(review): the right-hand side is computed in size_t, so a
    // "negative" result relies on the wrap-around converting to a negative
    // int64_t.
    int64_t mem_to_free =
        get_active_memory() + get_cache_memory() + size - memory_limit_;
    if (mem_to_free > 0) {
      buffer_cache_.release_cached_buffers(mem_to_free);
    }

    // Try the scalar pool first
    if (size <= small_block_size) {
      buf = scalar_pool_.malloc();
    }
    // The mutex is not held while new memory is allocated below.
    lock.unlock();
    if (!buf) {
      void* data = nullptr;
      if (device == -1) {
        data = unified_malloc(size);
      } else {
        cu::device(device).make_current();
        if (mem_pools_[device]) { // supports memory pools
          CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
        } else {
          CHECK_CUDA_ERROR(cudaMalloc(&data, size));
        }
      }
      if (!data) {
        std::ostringstream msg;
        msg << "[malloc] Unable to allocate " << size << " bytes.";
        throw std::runtime_error(msg.str());
      }
      buf = new CudaBuffer{data, size, device};
    }
    lock.lock();

    // If any cuda memory pool has too much reserved memory, clear some
    // memory from the cache. This prevents graph / kernel execution failing
    // from OOM
    if (get_cache_memory() > 0) {
      for (auto p : mem_pools_) {
        if (p) {
          size_t used = 0;
          CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
              p, cudaMemPoolAttrReservedMemCurrent, &used));
          if (used > (total_memory_ - free_limit_)) {
            buffer_cache_.release_cached_buffers(free_limit_);
            break;
          }
        }
      }
    }
  }
  active_memory_ += buf->size;
  peak_memory_ = std::max(active_memory_, peak_memory_);

  // Maintain the cache below the requested limit.
  if (get_cache_memory() > max_pool_size_) {
    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
  }
  lock.unlock();

  // Copy to unified memory here if the buffer is not on the right device.
  if (buf->device >= 0 && buf->device != device) {
    move_to_unified_memory(*buf, stream);
  }
  return Buffer{buf};
}

// Stream-less allocation; it always ends up in unified memory (device -1).
Buffer CudaAllocator::malloc(size_t size) {
  return malloc_async(size, -1, nullptr);
}

// Returns a buffer to the cache while the cache is below max_pool_size_, and
// frees it otherwise. Zero-size buffers (from malloc_async(0)) are simply
// deleted.
void CudaAllocator::free(Buffer buffer) {
  auto* buf = static_cast(buffer.ptr());
  if (!buf) {
    return;
  }
  if (buf->size == 0) {
    delete buf;
    return;
  }

  std::unique_lock lock(mutex_);
  active_memory_ -= buf->size;
  if (get_cache_memory() < max_pool_size_) {
    buffer_cache_.recycle_to_cache(buf);
  } else {
    free_cuda_buffer(buf);
  }
}

size_t CudaAllocator::size(Buffer buffer) const {
  auto* buf = static_cast(buffer.ptr());
  if (!buf) {
    return 0;
  }
  return buf->size;
}

// Moves a device-resident buffer into freshly allocated unified memory; it
// is a no-op if the buffer is already unified. With a stream and a memory
// pool for the buffer's device, the copy and the free of the old memory are
// both enqueued on `stream`. Otherwise a synchronous cudaMemcpy is used.
// Afterwards `buf` points at the new memory, with device -1.
void CudaAllocator::move_to_unified_memory(
    CudaBuffer& buf,
    cudaStream_t stream) {
  if (buf.device == -1) {
    return;
  }
  void* data = unified_malloc(buf.size);
  // cudaMemcpyDefault for managed memory; otherwise an explicit
  // device-to-host copy into the page-locked host allocation.
  cudaMemcpyKind kind =
      supports_managed_memory() ? cudaMemcpyDefault : cudaMemcpyDeviceToHost;
  if (stream && mem_pools_[buf.device]) {
    CHECK_CUDA_ERROR(cudaMemcpyAsync(data, buf.data, buf.size, kind, stream));
    free_async(buf, stream);
  } else {
    CHECK_CUDA_ERROR(cudaMemcpy(data, buf.data, buf.size, kind));
    free_async(buf);
  }
  buf.data = data;
  buf.device = -1;
}

// Releases a buffer for good. Pool blocks go back to the small pool; other
// buffers have their memory freed and their CudaBuffer deleted.
// This must be called with mutex_ acquired
void CudaAllocator::free_cuda_buffer(CudaBuffer* buf) {
  if (scalar_pool_.in_pool(buf)) {
    scalar_pool_.free(buf);
  } else {
    free_async(*buf);
    delete buf;
  }
}

// Frees the memory behind `buf`:
// - unified memory via unified_free;
// - device memory via cudaFreeAsync when the device has a memory pool, on
//   `stream` or, if none is given, on the device's free stream;
// - otherwise via cudaFree.
void CudaAllocator::free_async(CudaBuffer& buf, cudaStream_t stream) {
  if (buf.device == -1) {
    unified_free(buf.data);
  } else {
    // Free asynchronously when memory pools are supported.
    if (mem_pools_[buf.device]) {
      if (!stream) {
        stream = free_streams_[buf.device];
      }
      CHECK_CUDA_ERROR(cudaFreeAsync(buf.data, stream));
    } else {
      CHECK_CUDA_ERROR(cudaFree(buf.data));
    }
  }
}

// The getters below do not take mutex_. malloc_async calls
// get_active_memory / get_cache_memory while already holding it.
size_t CudaAllocator::get_active_memory() const {
  return active_memory_;
}

size_t CudaAllocator::get_peak_memory() const {
  return peak_memory_;
}

// Resets the peak counter to 0 (not to the current active memory).
void CudaAllocator::reset_peak_memory() {
  std::lock_guard lock(mutex_);
  peak_memory_ = 0;
}

size_t CudaAllocator::get_memory_limit() {
  return memory_limit_;
}

// Installs a new memory limit and returns the previous one.
size_t CudaAllocator::set_memory_limit(size_t limit) {
  std::lock_guard lock(mutex_);
  std::swap(limit, memory_limit_);
  return limit;
}

size_t CudaAllocator::get_cache_memory() const {
  return buffer_cache_.cache_size();
}

// Installs a new cache limit and returns the previous one. The cache is not
// trimmed here; malloc_async trims it on its next call.
size_t CudaAllocator::set_cache_limit(size_t limit) {
  std::lock_guard lk(mutex_);
  std::swap(limit, max_pool_size_);
  return limit;
}

void CudaAllocator::clear_cache() {
  std::lock_guard lk(mutex_);
  buffer_cache_.clear();
}

CudaAllocator& allocator() {
  static auto* allocator_ = []() {
    // Ensure scheduler is created before allocator.
    scheduler::scheduler();
    // By creating the |allocator_| on heap, the destructor of CudaAllocator
    // will not be called on exit and buffers in the cache will be leaked. This
    // can save some time at program exit.
return new CudaAllocator(); }(); return *allocator_; } Buffer malloc_async(size_t size, CommandEncoder& encoder) { return allocator().malloc_async( size, encoder.device().cuda_device(), encoder.stream()); } } // namespace cu namespace allocator { Allocator& allocator() { return cu::allocator(); } void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } auto& cbuf = *static_cast(ptr_); cu::allocator().move_to_unified_memory(cbuf); return cbuf.data; } } // namespace allocator size_t get_active_memory() { return cu::allocator().get_active_memory(); } size_t get_peak_memory() { return cu::allocator().get_peak_memory(); } void reset_peak_memory() { return cu::allocator().reset_peak_memory(); } size_t set_memory_limit(size_t limit) { return cu::allocator().set_memory_limit(limit); } size_t get_memory_limit() { return cu::allocator().get_memory_limit(); } size_t get_cache_memory() { return cu::allocator().get_cache_memory(); } size_t set_cache_limit(size_t limit) { return cu::allocator().set_cache_limit(limit); } void clear_cache() { cu::allocator().clear_cache(); } // Not supported in CUDA. size_t set_wired_limit(size_t) { return 0; } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/allocator.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include "mlx/allocator.h" #include "mlx/backend/common/buffer_cache.h" #include "mlx/backend/cuda/cuda_utils.h" #include #include #include #include namespace mlx::core::cu { class CommandEncoder; using allocator::Buffer; // Stores cuda-managed unified memory. 
struct CudaBuffer { void* data; size_t size; int device; // -1 for managed }; class SmallSizePool { private: union Block { Block* next; CudaBuffer buf; }; Block* buffer_{nullptr}; void* data_{nullptr}; Block* next_free_{nullptr}; public: SmallSizePool(); ~SmallSizePool(); SmallSizePool(const SmallSizePool&) = delete; SmallSizePool& operator=(const SmallSizePool&) = delete; CudaBuffer* malloc(); void free(CudaBuffer* buf); bool in_pool(CudaBuffer* buf); }; class CudaAllocator : public allocator::Allocator { public: Buffer malloc(size_t size) override; Buffer malloc_async(size_t size, int device, cudaStream_t stream); void free(Buffer buffer) override; size_t size(Buffer buffer) const override; // Replace the memory of |buf| with unified memory (managed memory or pinned // host memory), and copy the data over. Pass |stream| to copy asynchronously. void move_to_unified_memory(CudaBuffer& buf, cudaStream_t stream = nullptr); size_t get_active_memory() const; size_t get_peak_memory() const; void reset_peak_memory(); size_t get_memory_limit(); size_t set_memory_limit(size_t limit); size_t get_cache_memory() const; size_t set_cache_limit(size_t limit); void clear_cache(); private: void free_cuda_buffer(CudaBuffer* buf); void free_async(CudaBuffer& buf, cudaStream_t stream = nullptr); CudaAllocator(); friend CudaAllocator& allocator(); std::mutex mutex_; size_t memory_limit_; size_t free_limit_; size_t total_memory_; size_t max_pool_size_; BufferCache buffer_cache_; size_t active_memory_{0}; size_t peak_memory_{0}; std::vector free_streams_; std::vector mem_pools_; SmallSizePool scalar_pool_; }; CudaAllocator& allocator(); Buffer malloc_async(size_t size, CommandEncoder& encoder); } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/arange.cu ================================================ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include
#include

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Fills out[0, size) with `start + i * step`.
// Thread `index` owns the N_WRITES consecutive elements starting at
// index * N_WRITES. A fully in-range chunk is written with one store_vector;
// the last, partial chunk (and any thread past the end) uses a scalar loop.
template
__global__ void arange(T* out, IdxT size, T start, T step) {
  IdxT index = cg::this_grid().thread_rank();
  if ((index + 1) * N_WRITES > size) {
    // Partial (or empty) chunk at the end of the output.
    for (IdxT i = index * N_WRITES; i < size; ++i) {
      out[i] = start + i * step;
    }
  } else {
    AlignedVector out_vec;
#pragma unroll
    for (int i = 0; i < N_WRITES; ++i) {
      out_vec[i] = start + (index * N_WRITES + i) * step;
    }
    store_vector(out, index, out_vec);
  }
}

} // namespace cu

// Allocates `out` on this stream's command encoder and launches cu::arange
// over out.data_size() elements. Empty outputs are a no-op.
void Arange::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("Arange::eval_gpu");
  if (out.size() == 0) {
    return;
  }
  auto& encoder = cu::get_command_encoder(stream());
  out.set_data(cu::malloc_async(out.nbytes(), encoder));
  encoder.set_output_array(out);
  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
    using CTYPE = MLX_GET_TYPE(type_tag);
    using OutType = cuda_type_t;
    // Elements per thread chosen so that each thread writes 16 bytes.
    constexpr int N_WRITES = 16 / sizeof(OutType);
    // `large` selects the wider IdxT once the output exceeds INT32_MAX
    // elements.
    dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
      using IdxT = std::conditional_t;
      auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
      // The step handed to the kernel is the difference of the casted values
      // (start_ + step_) and start_, not a cast of step_ itself.
      // NOTE(review): the static_cast target type is not visible here;
      // presumably the kernel's element type.
      encoder.add_kernel_node(
          cu::arange,
          num_blocks,
          block_dims,
          gpu_ptr(out),
          out.data_size(),
          static_cast(start_),
          static_cast(start_ + step_) - static_cast(start_));
    });
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/arg_reduce.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include
#include
#include
#include
#include

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// A candidate result of an arg-reduction: the position along the reduced axis
// and the value found there.
template
struct IndexValPair {
  uint32_t index;
  T val;
};

// Reduction operator for argmin. Among equal values the smaller index wins.
template
struct ArgMin {
  // Starting value of the running minimum. arg_reduce_general also passes it
  // to load_vector (presumably as the fill for positions past axis_size).
  constexpr __device__ T init() {
    return Limits::max();
  }

  // Combines two candidates; used as the cub::BlockReduce operator.
  __device__ IndexValPair operator()(
      const IndexValPair& best,
      const IndexValPair& current) {
    if (best.val > current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  // Folds N consecutive values, the first of which sits at index `offset`,
  // into `best`. The strict `<` keeps the earliest index among equal values.
  template
  __device__ IndexValPair reduce_many(
      IndexValPair best,
      const AlignedVector& vals,
      uint32_t offset) {
#pragma unroll
    for (int i = 0; i < N; i++) {
      if (vals[i] < best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

// Reduction operator for argmax. Among equal values the smaller index wins.
template
struct ArgMax {
  // Starting value of the running maximum (see ArgMin::init).
  constexpr __device__ T init() {
    return Limits::min();
  }

  // Combines two candidates; used as the cub::BlockReduce operator.
  __device__ IndexValPair operator()(
      const IndexValPair& best,
      const IndexValPair& current) {
    if (best.val < current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  // Folds N consecutive values, the first of which sits at index `offset`,
  // into `best`. The strict `>` keeps the earliest index among equal values.
  template
  __device__ IndexValPair reduce_many(
      IndexValPair best,
      const AlignedVector& vals,
      uint32_t offset) {
#pragma unroll
    for (int i = 0; i < N; i++) {
      if (vals[i] > best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

// Computes one output element per thread block: the block rank `index` is
// mapped to input/output offsets with elem_to_loc over the shape/strides that
// have the reduced axis removed.
// The block's threads scan the axis (stride `axis_stride`) in N_READS-wide
// chunks, a cub::BlockReduce merges the per-thread candidates, and thread 0
// stores the winning index.
template
__global__ void arg_reduce_general(
    const T* in,
    uint32_t* out,
    size_t size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides in_strides,
    const __grid_constant__ Strides out_strides,
    int32_t ndim,
    int64_t axis_stride,
    int32_t axis_size) {
  auto block = cg::this_thread_block();
  int64_t index = cg::this_grid().block_rank();
  // The whole block exits together, so the __shared__ reduction below is safe.
  if (index >= size) {
    return;
  }
  int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
  int64_t out_idx =
      elem_to_loc(index, shape.data(), out_strides.data(), ndim);
  in += in_idx;

  Op op;
  T init = op.init();
  IndexValPair best{0, init};

  // Each pass covers BLOCK_DIM * N_READS elements of the axis. Chunk `tid`
  // starts at element tid * N_READS (the offset passed to reduce_many).
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    auto tid = r * BLOCK_DIM + block.thread_index().x;
    auto vals = load_vector(in, tid, axis_size, axis_stride, init);
    best = op.reduce_many(best, vals, tid * N_READS);
  }

  typedef cub::BlockReduce, BLOCK_DIM> BlockReduceT;
  __shared__ typename BlockReduceT::TempStorage temp;

  best = BlockReduceT(temp).Reduce(best, op);

  if (block.thread_rank() == 0) {
    out[out_idx] = best.index;
  }
}

} // namespace cu

// Argmin/argmax along axis_. It allocates `out`, drops the reduced axis from
// the shape/strides, and launches cu::arg_reduce_general.
void ArgReduce::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("ArgReduce::eval_gpu");
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  out.set_data(cu::malloc_async(out.nbytes(), encoder));

  // Prepare the shapes, strides and axis arguments.
  Shape shape = remove_index(in.shape(), axis_);
  Strides in_strides = remove_index(in.strides(), axis_);
  // When `out` keeps the reduced axis (same rank as `in`), drop its stride
  // too so both stride vectors line up with `shape`.
  Strides out_strides = out.ndim() == in.ndim()
      ? remove_index(out.strides(), axis_)
      : out.strides();
  int64_t axis_stride = in.strides()[axis_];
  int32_t axis_size = in.shape()[axis_];
  int32_t ndim = shape.size();

  // ArgReduce.
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
    using T = cuda_type_t;
    constexpr uint32_t N_READS = 4;
    // The block size is chosen from the number of N_READS-wide chunks on the
    // reduced axis.
    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
      // NOTE(review): presumably yields out.size() blocks, matching the
      // kernel's one-block-per-output mapping.
      dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
      // NOTE(review): the Op template arguments are not visible here;
      // presumably the default is the ArgMax instantiation, replaced by the
      // ArgMin one below.
      auto kernel =
          cu::arg_reduce_general, block_dim(), N_READS>;
      if (reduce_type_ == ArgReduce::ArgMin) {
        kernel =
            cu::arg_reduce_general, block_dim(), N_READS>;
      }
      encoder.add_kernel_node(
          kernel,
          num_blocks,
          block_dim(),
          gpu_ptr(in),
          gpu_ptr(out),
          out.size(),
          const_param(shape),
          const_param(in_strides),
          const_param(out_strides),
          ndim,
          axis_stride,
          axis_size);
    });
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/bin2h.cmake
================================================
# Based on: https://github.com/sivachandran/cmake-bin2h
#
# Copyright 2020 Sivachandran Paramasivam
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

include(CMakeParseArguments)

# Function to wrap a given string into multiple lines at the given column
# position. Each chunk of AT_COLUMN characters is emitted on its own line,
# preceded by a newline and a single space.
#
# Parameters:
#
# * VARIABLE - The name of the CMake variable holding the string. The wrapped
#   result is written back to that variable in the caller's scope.
# * AT_COLUMN - The column position at which string will be wrapped.
function(WRAP_STRING)
  set(oneValueArgs VARIABLE AT_COLUMN)
  # No options: pass "" explicitly rather than an `options` variable that
  # would otherwise be inherited from the caller's scope.
  cmake_parse_arguments(WRAP_STRING "" "${oneValueArgs}" "" ${ARGN})

  # Start from an empty result. A function scope sees the caller's variables,
  # so an unset `lines` could otherwise pick up an outer value.
  set(lines "")
  # Quoted so an empty input yields length 0 instead of an argument error.
  string(LENGTH "${${WRAP_STRING_VARIABLE}}" stringLength)
  math(EXPR offset "0")
  while(stringLength GREATER 0)
    if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
      math(EXPR length "${WRAP_STRING_AT_COLUMN}")
    else()
      math(EXPR length "${stringLength}")
    endif()
    string(SUBSTRING "${${WRAP_STRING_VARIABLE}}" ${offset} ${length} line)
    set(lines "${lines}\n ${line}")
    math(EXPR stringLength "${stringLength} - ${length}")
    math(EXPR offset "${offset} + ${length}")
  endwhile()
  set(${WRAP_STRING_VARIABLE} "${lines}" PARENT_SCOPE)
endfunction()

# Function to embed contents of a file as byte array in C/C++ header file(.h).
# The header file will contain one `constexpr char` array per source file,
# preceded by `#pragma once`.
#
# Parameters:
#
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
#   the header file.
# * VARIABLE_NAME - The prefix of the array names. Each array is named
#   "<VARIABLE_NAME>_<file name without extension>", with the file name turned
#   into a valid C identifier.
# * HEADER_FILE - The path of header file.
# * APPEND - If specified appends to the header file instead of overwriting it.
# * HEADER_NAMESPACE - The namespace the arrays are placed in. If omitted, the
#   arrays are emitted at global scope.
# * NULL_TERMINATE - If specified a null byte (zero) is appended to each byte
#   array.
#
# Usage:
#
# bin2h(SOURCE_FILES "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
function(BIN2H)
  set(options APPEND NULL_TERMINATE)
  set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
  set(multiValueArgs SOURCE_FILES)
  cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN})

  set(arrayDefinition "")
  foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
    # get filename without extension
    get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
    # convert the filename to a valid C identifier
    string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)

    # reads source file contents as hex string
    file(READ ${SOURCE_FILE} hexString HEX)
    # append null
    if(BIN2H_NULL_TERMINATE)
      string(APPEND hexString "00")
    endif()

    # Strip the © in source code: replace its UTF-8 bytes (c2 a9) with two
    # spaces (20 20). Bytes above 0x7f cannot initialize a signed `char` in a
    # brace list without narrowing.
    # This must run BEFORE wrap_string. Wrapping inserts a line break every
    # 24 hex digits, which can split the c2/a9 pair so the pattern no longer
    # matches.
    # NOTE: the match is not forced onto a byte boundary. For valid UTF-8
    # input a misaligned "c2a9" cannot occur, because it would need byte 0x2a
    # to be followed by a continuation byte.
    string(REGEX REPLACE "c2a9" "2020" hexString "${hexString}")

    # wraps the hex string into multiple lines
    wrap_string(VARIABLE hexString AT_COLUMN 24)

    # turn every byte "xx" into " 0xxx,"
    string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
                         "${hexString}")

    # make a full variable name for the array
    set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")

    # declares the byte array
    string(APPEND arrayDefinition
           "constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
  endforeach()

  # add namespace wrapper if defined; otherwise emit the arrays unwrapped
  # (previously `declarations` was left unset here and the arrays were lost)
  if(DEFINED BIN2H_HEADER_NAMESPACE)
    set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
    set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
    set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
  else()
    set(declarations "${arrayDefinition}")
  endif()

  set(arrayIncludes "#pragma once")
  string(PREPEND declarations "${arrayIncludes}\n\n")

  if(BIN2H_APPEND)
    file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
  else()
    file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
  endif()
endfunction()

# ----------------------------- CLI args -----------------------------

string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
foreach(source ${MLX_JIT_SOURCES_LIST})
  list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
endforeach()

bin2h(
  SOURCE_FILES
  ${MLX_JIT_SOURCES_ABS}
  NULL_TERMINATE
  VARIABLE_NAME
  "jit_source"
  HEADER_NAMESPACE
  "mlx::core"
  HEADER_FILE
  "${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")

================================================
FILE: mlx/backend/cuda/binary/CMakeLists.txt
================================================
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)

================================================
FILE: mlx/backend/cuda/binary/add.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
BINARY_GPU(Add)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/arctan2.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// ArcTan2::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(ArcTan2)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/binary.cuh
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include
#include

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Threads-per-block cap for the contiguous kernels. It is used as their
// __launch_bounds__ and passed to get_launch_args in binary_op_gpu_inplace.
constexpr int BINARY_MAX_BLOCK_DIM = 1024;

// Contiguous kernels (binary_ss / binary_sv / binary_vs / binary_vv).
//
// Thread `index` produces outputs [index * N_READS, index * N_READS + N_READS).
// A thread whose chunk would run past `size` uses an element-by-element loop;
// every other thread loads/stores a full N_READS-wide vector.
// The tail test `index >= size / N_READS` is equivalent to
// `(index + 1) * N_READS > size`, but it cannot wrap around. The product can
// overflow a 32-bit IdxT when `size` is close to its maximum (the non-large
// contiguous path in binary_op_gpu_inplace is used up to UINT32_MAX elements).

// Scalar-scalar: every output is Op(a[0], b[0]).
template
__global__ __launch_bounds__(BINARY_MAX_BLOCK_DIM) void binary_ss(
    const In* a,
    const In* b,
    Out* out,
    IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index >= size / N_READS) {
    // IdxT, not int, like the sibling kernels: offsets above INT32_MAX are
    // reachable on the non-large path and would overflow an int.
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = Op{}(a[0], b[0]);
    }
  } else {
    AlignedVector out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = Op{}(a[0], b[0]);
    }
    store_vector(out, index, out_vec);
  }
}

// Scalar-vector: out[i] = Op(a[0], b[i]).
template
__global__ __launch_bounds__(BINARY_MAX_BLOCK_DIM) void binary_sv(
    const In* a,
    const In* b,
    Out* out,
    IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index >= size / N_READS) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = Op{}(a[0], b[i]);
    }
  } else {
    auto b_vec = load_vector(b, index);
    AlignedVector out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = Op{}(a[0], b_vec[i]);
    }
    store_vector(out, index, out_vec);
  }
}

// Vector-scalar: out[i] = Op(a[i], b[0]).
template
__global__ __launch_bounds__(BINARY_MAX_BLOCK_DIM) void binary_vs(
    const In* a,
    const In* b,
    Out* out,
    IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index >= size / N_READS) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = Op{}(a[i], b[0]);
    }
  } else {
    auto a_vec = load_vector(a, index);
    AlignedVector out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = Op{}(a_vec[i], b[0]);
    }
    store_vector(out, index, out_vec);
  }
}

// Vector-vector: out[i] = Op(a[i], b[i]).
template
__global__ __launch_bounds__(BINARY_MAX_BLOCK_DIM) void binary_vv(
    const In* a,
    const In* b,
    Out* out,
    IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index >= size / N_READS) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = Op{}(a[i], b[i]);
    }
  } else {
    auto a_vec = load_vector(a, index);
    auto b_vec = load_vector(b, index);
    AlignedVector out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = Op{}(a_vec[i], b_vec[i]);
    }
    store_vector(out, index, out_vec);
  }
}

// General (strided) layout with the rank fixed at compile time (NDIM).
// The y grid dimension selects one of the `size_rest` rows formed by all but
// the innermost dimension. The x grid dimension selects an N_READS-wide chunk
// along the innermost dimension.
// load_vector / store_vector receive shape_x (and In(0) for loads),
// presumably as the bound and fill for lanes past the end of the row.
// Output rows are written contiguously at out + shape_x * index_rest.
template <
    typename Op,
    typename In,
    typename Out,
    typename IdxT,
    int NDIM,
    int N_READS>
__global__ void binary_g_nd(
    const In* a,
    const In* b,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ cuda::std::array shape,
    const __grid_constant__ cuda::std::array a_strides,
    const __grid_constant__ cuda::std::array b_strides) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[NDIM - 1];
  auto a_stride_x = a_strides[NDIM - 1];
  auto b_stride_x = b_strides[NDIM - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [a_idx, b_idx] = elem_to_loc_nd(
      index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
  auto a_vec =
      load_vector(a + a_idx, index_x, shape_x, a_stride_x, In(0));
  auto b_vec =
      load_vector(b + b_idx, index_x, shape_x, b_stride_x, In(0));

  AlignedVector out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = Op{}(a_vec[i], b_vec[i]);
  }
  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}

// Same as binary_g_nd, but the rank is a runtime argument (`ndim`) and the
// shape/strides are passed as Shape/Strides.
template
__global__ void binary_g(
    const In* a,
    const In* b,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    int ndim) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[ndim - 1];
  auto a_stride_x = a_strides[ndim - 1];
  auto b_stride_x = b_strides[ndim - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [a_idx, b_idx] = elem_to_loc(
      index_rest * shape_x,
      shape.data(),
      a_strides.data(),
      b_strides.data(),
      ndim);
  auto a_vec =
      load_vector(a + a_idx, index_x, shape_x, a_stride_x, In(0));
  auto b_vec =
      load_vector(b + b_idx, index_x, shape_x, b_stride_x, In(0));

  AlignedVector out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = Op{}(a_vec[i], b_vec[i]);
  }
  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}

// Compile-time whitelist of (Op, In, Out) combinations. binary_op_gpu_inplace
// only instantiates kernels when this is true and throws otherwise.
template
constexpr bool supports_binary_op() {
  if (std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v) {
    return std::is_same_v;
  }
  if (std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v) {
    return std::is_same_v;
  }
  if (std::is_same_v || std::is_same_v) {
    return std::is_same_v && std::is_same_v;
  }
  if (std::is_same_v) {
    return std::is_same_v && is_inexact_v;
  }
  if (std::is_same_v) {
    return std::is_same_v && is_inexact_v;
  }
  if (std::is_same_v) {
    return std::is_same_v && is_floating_v;
  }
  if (std::is_same_v || std::is_same_v ||
      std::is_same_v) {
    return std::is_same_v && std::is_integral_v;
  }
  if (std::is_same_v || std::is_same_v) {
    return std::is_same_v && std::is_integral_v &&
        !std::is_same_v;
  }
  return false;
}

} // namespace cu

// Encodes the kernel computing out = Op(inputs[0], inputs[1]) on stream `s`.
// `out` must already have its data set. `op` is only used in the error
// message for unsupported dtype combinations.
//
// - BinaryOpType::General: dims are collapsed with collapse_contiguous_dims,
//   then binary_g_nd (ndim <= 3) or binary_g runs on a 2-D grid
//   (x: innermost dim, y: remaining rows). Each thread handles 4 elements
//   when the innermost dim has at least 4, otherwise 1.
// - Any other layout: one of the contiguous kernels over out.data_size().
template
void binary_op_gpu_inplace(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  assert(inputs.size() > 1);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  if (out.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
      if constexpr (cu::supports_binary_op()) {
        using InType = cuda_type_t;
        using OutType = cuda_type_t;
        auto bopt = get_binary_op_type(a, b);
        if (bopt == BinaryOpType::General) {
          // `large` is set when any buffer holds more than INT32_MAX elements.
          dispatch_bool(
              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
                  out.data_size() > INT32_MAX,
              [&](auto large) {
                using IdxT = std::conditional_t;
                Shape shape;
                std::vector strides;
                std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);
                auto& a_strides = strides[0];
                auto& b_strides = strides[1];
                int ndim = shape.size();
                int work_per_thread = 1;
                auto dim0 = ndim > 0 ? shape.back() : 1;
                auto rest = out.size() / dim0;
                if (dim0 >= 4) {
                  work_per_thread = 4;
                }
                dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
                auto block_dims = get_block_dims(dim0, rest, 1);
                uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
                uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
                if (ndim <= 3) {
                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
                    auto kernel = cu::binary_g_nd<
                        Op,
                        InType,
                        OutType,
                        IdxT,
                        dims_constant(),
                        1>;
                    if (work_per_thread == 4) {
                      kernel = cu::binary_g_nd<
                          Op,
                          InType,
                          OutType,
                          IdxT,
                          dims_constant(),
                          4>;
                    }
                    encoder.add_kernel_node(
                        kernel,
                        {num_blocks_x, num_blocks_y},
                        block_dims,
                        gpu_ptr(a),
                        gpu_ptr(b),
                        gpu_ptr(out),
                        rest,
                        const_param(shape),
                        const_param(a_strides),
                        const_param(b_strides));
                  });
                } else {
                  auto kernel = cu::binary_g;
                  if (work_per_thread == 4) {
                    kernel = cu::binary_g;
                  }
                  encoder.add_kernel_node(
                      kernel,
                      {num_blocks_x, num_blocks_y},
                      block_dims,
                      gpu_ptr(a),
                      gpu_ptr(b),
                      gpu_ptr(out),
                      rest,
                      const_param(shape),
                      const_param(a_strides),
                      const_param(b_strides),
                      ndim);
                }
              });
        } else {
          // Contiguous layouts: the wider index type is used only above
          // UINT32_MAX elements.
          dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
            using IdxT = std::conditional_t;
            // 16-byte vector accesses per thread.
            constexpr int N_READS = 16 / sizeof(InType);
            auto kernel = cu::binary_ss;
            if (bopt == BinaryOpType::ScalarVector) {
              kernel = cu::binary_sv;
            } else if (bopt == BinaryOpType::VectorScalar) {
              kernel = cu::binary_vs;
            } else if (bopt == BinaryOpType::VectorVector) {
              kernel = cu::binary_vv;
            }
            auto [num_blocks, block_dims] = get_launch_args(
                out.data_size(),
                out.shape(),
                out.strides(),
                large(),
                N_READS,
                cu::BINARY_MAX_BLOCK_DIM);
            encoder.add_kernel_node(
                kernel,
                num_blocks,
                block_dims,
                gpu_ptr(a),
                gpu_ptr(b),
                gpu_ptr(out),
                out.data_size());
          });
        }
      } else {
        throw std::runtime_error(
            fmt::format(
                "Can not do binary op {} on inputs of {} with result of {}.",
                op,
                dtype_to_string(a.dtype()),
                dtype_to_string(out.dtype())));
      }
    });
  });
}

// Gives `out` its storage and then runs binary_op_gpu_inplace.
// set_binary_op_output_data decides how `out` is backed; fresh buffers come
// from cu::malloc_async on this stream's command encoder.
template
void binary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  auto& encoder = cu::get_command_encoder(s);
  set_binary_op_output_data(
      a, b, out, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
  binary_op_gpu_inplace(inputs, out, op, s);
}

// Defines func::eval_gpu as a call to binary_op_gpu on the output
// primitive's stream, wrapped in an NVTX range named "<func>::eval_gpu".
#define BINARY_GPU(func)                                                \
  void func::eval_gpu(const std::vector& inputs, array& out) {        \
    nvtx3::scoped_range r(#func "::eval_gpu");                          \
    auto& s = out.primitive().stream();                                 \
    binary_op_gpu(inputs, out, name(), s);                              \
  }

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/bitwise_binary.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {

// Selects, based on op_, the binary_op_gpu call for the requested bitwise
// operation; all of them run on the output primitive's stream.
// NOTE(review): the Op template argument of each binary_op_gpu call is
// missing from this text, so the case -> functor mapping cannot be checked
// here.
void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
  auto& s = out.primitive().stream();
  switch (op_) {
    case BitwiseBinary::And:
      binary_op_gpu(inputs, out, name(), s);
      break;
    case BitwiseBinary::Or:
      binary_op_gpu(inputs, out, name(), s);
      break;
    case BitwiseBinary::Xor:
      binary_op_gpu(inputs, out, name(), s);
      break;
    case BitwiseBinary::LeftShift:
      binary_op_gpu(inputs, out, name(), s);
      break;
    case BitwiseBinary::RightShift:
      binary_op_gpu(inputs, out, name(), s);
      break;
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/divide.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// Divide::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(Divide)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/equal.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {

// Elementwise equality on the output primitive's stream. The two branches
// choose the comparison functor depending on equal_nan_.
// NOTE(review): both calls read identically because their Op template
// arguments are missing from this text; presumably the equal_nan_ branch
// uses a NaN-aware functor.
void Equal::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("Equal::eval_gpu");
  auto& s = out.primitive().stream();
  if (equal_nan_) {
    binary_op_gpu(inputs, out, name(), s);
  } else {
    binary_op_gpu(inputs, out, name(), s);
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/greater.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// Greater::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(Greater)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/greater_equal.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// GreaterEqual::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(GreaterEqual)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/less.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// Less::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(Less)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/less_equal.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// LessEqual::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(LessEqual)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/log_add_exp.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// LogAddExp::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(LogAddExp)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/logical_and.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// LogicalAnd::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(LogicalAnd)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/logical_or.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// LogicalOr::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(LogicalOr)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/maximum.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// Maximum::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(Maximum)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/minimum.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// Minimum::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(Minimum)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/multiply.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// Multiply::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(Multiply)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/not_equal.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// NotEqual::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(NotEqual)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/power.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// Power::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(Power)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/remainder.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// Remainder::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(Remainder)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary/subtract.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/binary/binary.cuh"

namespace mlx::core {
// Subtract::eval_gpu, generated by the BINARY_GPU macro from binary.cuh.
BINARY_GPU(Subtract)
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/binary_two.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Contiguous kernels for a binary op that yields two outputs (Op returns an
// indexable pair). The suffix says whether `a`/`b` is a scalar (s) or a
// contiguous vector (v). Thread `index` covers elements
// [index * N_READS, (index + 1) * N_READS); the thread whose range crosses
// `size` handles the remaining tail with a scalar loop.

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[0], b[0]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a[0], b[0]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector<N_READS>(out_a, index, out_a_vec);
    store_vector<N_READS>(out_b, index, out_b_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[0], b[i]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    auto b_vec = load_vector<N_READS>(b, index);

    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a[0], b_vec[i]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector<N_READS>(out_a, index, out_a_vec);
    store_vector<N_READS>(out_b, index, out_b_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[i], b[0]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    auto a_vec = load_vector<N_READS>(a, index);

    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a_vec[i], b[0]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector<N_READS>(out_a, index, out_a_vec);
    store_vector<N_READS>(out_b, index, out_b_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[i], b[i]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    auto a_vec = load_vector<N_READS>(a, index);
    auto b_vec = load_vector<N_READS>(b, index);

    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a_vec[i], b_vec[i]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector<N_READS>(out_a, index, out_a_vec);
    store_vector<N_READS>(out_b, index, out_b_vec);
  }
}

// General (strided) kernel with the number of dims fixed at compile time.
// The 2D launch uses y for the "rest" (all but the last dim) and x for
// chunks of N_READS elements along the last dim.
template <
    typename Op,
    typename In,
    typename Out,
    typename IdxT,
    int NDIM,
    int N_READS>
__global__ void binary_two_g_nd(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size_rest,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[NDIM - 1];
  auto a_stride_x = a_strides[NDIM - 1];
  auto b_stride_x = b_strides[NDIM - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
      index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
  auto a_vec =
      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
  auto b_vec =
      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));

  AlignedVector<Out, N_READS> out_vec_a;
  AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    auto out = Op{}(a_vec[i], b_vec[i]);
    out_vec_a[i] = out[0];
    out_vec_b[i] = out[1];
  }
  store_vector<N_READS>(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
  store_vector<N_READS>(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}

// General (strided) kernel with a runtime number of dims (used for ndim > 3).
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_two_g(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size_rest,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    int ndim) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[ndim - 1];
  auto a_stride_x = a_strides[ndim - 1];
  auto b_stride_x = b_strides[ndim - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [a_idx, b_idx] = elem_to_loc(
      index_rest * shape_x,
      shape.data(),
      a_strides.data(),
      b_strides.data(),
      ndim);
  auto a_vec =
      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
  auto b_vec =
      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));

  AlignedVector<Out, N_READS> out_vec_a;
  AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    auto out = Op{}(a_vec[i], b_vec[i]);
    out_vec_a[i] = out[0];
    out_vec_b[i] = out[1];
  }
  store_vector<N_READS>(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
  store_vector<N_READS>(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}

// Only DivMod is wired up, and only when input and output dtypes match and
// are integral or floating point.
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_two_op() {
  if (std::is_same_v<Op, DivMod>) {
    return std::is_same_v<In, Out> &&
        (std::is_integral_v<Out> || is_floating_v<Out>);
  }
  return false;
}

} // namespace cu

// Allocates both outputs (stream-ordered, via cu::malloc_async) and enqueues
// the kernel computing Op(a, b) into outputs[0] / outputs[1].
// Throws std::runtime_error if Op does not support the dtype combination.
// NOTE: this function owns output allocation; callers must not pre-allocate.
template <typename Op>
void binary_two_op_gpu_inplace(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const char* op,
    const Stream& s) {
  assert(inputs.size() > 1);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  auto& out_a = outputs[0];
  auto& out_b = outputs[1];
  auto bopt = get_binary_op_type(a, b);
  auto& encoder = cu::get_command_encoder(s);

  set_binary_op_output_data(
      a, b, out_a, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
  set_binary_op_output_data(
      a, b, out_b, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });

  if (out_a.size() == 0) {
    return;
  }

  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out_a);
  encoder.set_output_array(out_b);
  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
      if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
        using InType = cuda_type_t<CTYPE_IN>;
        using OutType = cuda_type_t<CTYPE_OUT>;
        // `bopt` from above is reused; it depends only on a and b.
        if (bopt == BinaryOpType::General) {
          dispatch_bool(
              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
                  out_a.data_size() > INT32_MAX,
              [&](auto large) {
                using IdxT = std::conditional_t<large(), int64_t, int32_t>;
                Shape shape;
                std::vector<Strides> strides;
                std::tie(shape, strides) =
                    collapse_contiguous_dims(a, b, out_a);
                auto& a_strides = strides[0];
                auto& b_strides = strides[1];
                int ndim = shape.size();
                int work_per_thread = 1;
                auto dim0 = ndim > 0 ? shape.back() : 1;
                auto rest = out_a.size() / dim0;
                if (dim0 >= 4) {
                  work_per_thread = 4;
                }
                dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
                auto block_dims = get_block_dims(dim0, rest, 1);
                uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
                uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);

                if (ndim <= 3) {
                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
                    auto kernel = cu::binary_two_g_nd<
                        Op,
                        InType,
                        OutType,
                        IdxT,
                        dims_constant(),
                        1>;
                    if (work_per_thread == 4) {
                      kernel = cu::binary_two_g_nd<
                          Op,
                          InType,
                          OutType,
                          IdxT,
                          dims_constant(),
                          4>;
                    }
                    encoder.add_kernel_node(
                        kernel,
                        {num_blocks_x, num_blocks_y},
                        block_dims,
                        gpu_ptr<InType>(a),
                        gpu_ptr<InType>(b),
                        gpu_ptr<OutType>(out_a),
                        gpu_ptr<OutType>(out_b),
                        rest,
                        const_param<dims_constant()>(shape),
                        const_param<dims_constant()>(a_strides),
                        const_param<dims_constant()>(b_strides));
                  });
                } else {
                  auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 1>;
                  if (work_per_thread == 4) {
                    kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 4>;
                  }
                  encoder.add_kernel_node(
                      kernel,
                      {num_blocks_x, num_blocks_y},
                      block_dims,
                      gpu_ptr<InType>(a),
                      gpu_ptr<InType>(b),
                      gpu_ptr<OutType>(out_a),
                      gpu_ptr<OutType>(out_b),
                      rest,
                      const_param(shape),
                      const_param(a_strides),
                      const_param(b_strides),
                      ndim);
                }
              });
        } else {
          dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
            // 16-byte vectorized loads/stores.
            constexpr int N_READS = 16 / sizeof(InType);
            auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
            if (bopt == BinaryOpType::ScalarVector) {
              kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
            } else if (bopt == BinaryOpType::VectorScalar) {
              kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
            } else if (bopt == BinaryOpType::VectorVector) {
              kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
            }
            auto [num_blocks, block_dims] = get_launch_args(
                out_a.data_size(),
                out_a.shape(),
                out_a.strides(),
                large(),
                N_READS);
            encoder.add_kernel_node(
                kernel,
                num_blocks,
                block_dims,
                gpu_ptr<InType>(a),
                gpu_ptr<InType>(b),
                gpu_ptr<OutType>(out_a),
                gpu_ptr<OutType>(out_b),
                out_a.data_size());
          });
        }
      } else {
        throw std::runtime_error(
            fmt::format(
                "Can not do binary op {} on inputs of {} with result of {}.",
                op,
                dtype_to_string(a.dtype()),
                dtype_to_string(out_a.dtype())));
      }
    });
  });
}

// Entry point used by eval_gpu. Output allocation is done exactly once,
// inside binary_two_op_gpu_inplace (stream-ordered via the command encoder).
// Allocating here as well would allocate every output twice and discard any
// buffer donated by the first pass.
template <typename Op>
void binary_two_op_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const char* op,
    const Stream& s) {
  binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
}

void DivMod::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("DivMod::eval_gpu");
  auto& s = outputs[0].primitive().stream();
  binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/compiled.cpp
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"

#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>

#include <algorithm>
#include <sstream>

namespace mlx::core {

namespace cu {

// Emits CUDA source for a fused elementwise kernel that evaluates `tape`
// over `inputs` and writes `outputs`. `build` appends one kernel definition
// to `os`; it is called once for the contiguous variant and once for the
// strided (NDIM-templated) variant.
struct FusedKernelBuilder {
  std::string os;
  const std::string& kernel_name;
  const std::vector<array>& inputs;
  const std::vector<array>& outputs;
  const std::vector<array>& tape;
  const std::function<bool(size_t)>& is_constant;

  void build(const char* name, bool contiguous) {
    NodeNamer namer;

    // Function parameters. Constant inputs are inlined into the source and
    // get no parameter; non-scalar inputs of strided kernels get strides.
    std::vector<std::string> params;
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (is_constant(i)) {
        continue;
      }
      const auto& x = inputs[i];
      const std::string& xname = namer.get_name(x);
      params.push_back(
          fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname));
      if (!is_scalar(x) && !contiguous) {
        params.push_back(
            fmt::format(
                "const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides",
                xname));
      }
    }
    for (const auto& x : outputs) {
      params.push_back(
          fmt::format(
              "{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));
    }
    if (!contiguous) {
      params.push_back(
          "const __grid_constant__ cuda::std::array<int32_t, NDIM> shape");
    }
    params.push_back("IdxT size");

    // Build function signature.
    if (contiguous) {
      os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
    } else {
      os +=
          "template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
    }
    os += fmt::format("__global__ void {}(\n", kernel_name + name);
    for (size_t i = 0; i < params.size(); ++i) {
      os += " ";
      os += params[i];
      if (i != params.size() - 1) {
        os += ",\n";
      }
    }
    os += ") {\n";

    // Index. For non contiguous kernels we create a separate index
    // variable per variable otherwise everyone uses `index`.
    os +=
        " IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
        " if (index >= size) {\n"
        " return;\n"
        " }\n";
    if (!contiguous) {
      for (size_t i = 0; i < inputs.size(); ++i) {
        const auto& x = inputs[i];
        const std::string& xname = namer.get_name(x);
        if (is_scalar(x) || is_constant(i)) {
          continue;
        }
        os += " IdxT " + xname + "_idx = 0;\n";
      }
      os += " {\n";
      os += " IdxT loc = index;\n";
      os +=
          " #pragma unroll\n"
          " for (int i = NDIM - 1; i >= 0; i--) {\n";
      for (size_t i = 0; i < inputs.size(); ++i) {
        const auto& x = inputs[i];
        const std::string& xname = namer.get_name(x);
        if (is_scalar(x) || is_constant(i)) {
          continue;
        }
        // Plain '%': "\%" is an unknown escape sequence in C++.
        os += " " + xname + "_idx += (loc % shape[i]) * IdxT(" + xname +
            "_strides[i]);\n";
      }
      os +=
          " loc /= shape[i];\n"
          " }\n"
          " }\n";
    }

    // Vectorized read loop
    if (contiguous) {
      for (size_t i = 0; i < inputs.size(); ++i) {
        const auto& x = inputs[i];
        if (is_scalar(x) || is_constant(i)) {
          continue;
        }
        const std::string& xname = namer.get_name(x);
        std::string type = dtype_to_cuda_type(x.dtype());
        os += fmt::format(
            " auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
            xname,
            type);
      }
    }

    // Create some space for the outputs
    for (const auto& x : outputs) {
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      os += fmt::format(
          " AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
    }

    // Work loop
    if (!contiguous) {
      os +=
          "\n"
          " for (int i = 0; i < work_per_thread && index < size; i++) {\n";
    } else {
      os +=
          "\n"
          " #pragma unroll\n"
          " for (int i = 0; i < work_per_thread; i++) {\n";
    }

    // Read inputs.
    for (size_t i = 0; i < inputs.size(); ++i) {
      const auto& x = inputs[i];
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      std::string value;
      if (is_constant(i)) {
        std::ostringstream ss;
        print_constant(ss, x);
        value = fmt::format("static_cast<{}>({})", type, ss.str());
      } else if (is_scalar(x)) {
        value = fmt::format("{}[0]", xname);
      } else if (contiguous) {
        value = fmt::format("vec_{}[i]", xname);
      } else {
        value = fmt::format("{}[{}_idx]", xname, xname);
      }
      os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
    }

    // Write tape.
    for (const auto& x : tape) {
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      std::string value;
      if (is_static_cast(x.primitive())) {
        value = fmt::format(
            "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
      } else {
        // Emits `PrimName{}(tmp_a, tmp_b, ...)`.
        value = x.primitive().name();
        value += "{}(";
        for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
          value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
        }
        value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
      }
      os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
    }

    // Write output.
    for (const auto& x : outputs) {
      os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
    }

    // End of work loop
    if (!contiguous) {
      os += "\n";
      for (size_t i = 0; i < inputs.size(); ++i) {
        const auto& x = inputs[i];
        const std::string& xname = namer.get_name(x);
        if (is_scalar(x) || is_constant(i)) {
          continue;
        }
        os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
      }
    }
    os += " }\n";

    // Store the output to global memory
    for (const auto& x : outputs) {
      os += fmt::format(
          " store_vector({0} + index, 0, vec_{0}, size - index);\n",
          namer.get_name(x));
    }

    os += "}\n";
  }
};

} // namespace cu

constexpr const char* g_jit_includes = R"(
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/ternary_ops.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

#define inf cuda::std::numeric_limits<float>::infinity()
)";

void Compiled::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("Compiled::eval_gpu");
  auto& s = stream();

  // Determine the work per thread for the vectorized reads/writes. We take it
  // as 16 over the max itemsize for the outputs. Another heuristic could be
  // over the max itemsize of all arrays.
  int max_size = 1;
  for (const auto& x : outputs) {
    max_size = std::max(max_size, static_cast<int>(x.itemsize()));
  }
  int work_per_thread = 16 / max_size;

  cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
    // Build source code.
    cu::FusedKernelBuilder builder{
        g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
    builder.os +=
        "namespace mlx::core::cu {\n\n"
        "namespace cg = cooperative_groups;\n\n";
    builder.build("_contiguous", true);
    builder.os += "\n";
    builder.build("_strided", false);
    builder.os += "\n} // namespace mlx::core::cu\n";

    // Build kernel names: every instantiation the launch below may request.
    std::vector<std::string> kernel_names;
    kernel_names.push_back(
        fmt::format(
            "mlx::core::cu::{}_contiguous<uint32_t, {}>",
            lib_name(),
            work_per_thread));
    kernel_names.push_back(
        fmt::format(
            "mlx::core::cu::{}_contiguous<int64_t, {}>",
            lib_name(),
            work_per_thread));
    for (int wpt : {1, work_per_thread}) {
      for (int i = 1; i <= MAX_NDIM; ++i) {
        kernel_names.push_back(
            fmt::format(
                "mlx::core::cu::{}_strided<{}, uint32_t, {}>",
                lib_name(),
                i,
                wpt));
        kernel_names.push_back(
            fmt::format(
                "mlx::core::cu::{}_strided<{}, int64_t, {}>",
                lib_name(),
                i,
                wpt));
      }
    }

    return std::make_tuple(
        false, std::move(builder.os), std::move(kernel_names));
  });

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
  auto [contiguous, shape, strides_vec] =
      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);

  // Whether to use large index.
  bool large = compiled_use_large_index(inputs, outputs, contiguous);

  cu::KernelArgs args;
  // Put inputs. Argument order must match FusedKernelBuilder::build.
  int strides_index = 1;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (is_constant_(i)) {
      continue;
    }
    const auto& x = inputs[i];
    args.append(x);
    if (!contiguous && !is_scalar(x)) {
      args.append_ptr(strides_vec[strides_index++].data());
    }
  }

  auto& encoder = cu::get_command_encoder(s);

  // Put outputs.
  compiled_allocate_outputs(
      inputs, outputs, is_constant_, contiguous, [&](auto n) {
        return cu::malloc_async(n, encoder);
      });
  for (auto& x : outputs) {
    args.append(x);
  }

  // Put shape and size.
  if (!contiguous) {
    args.append_ptr(shape.data());
  }
  if (large) {
    args.append<int64_t>(outputs[0].data_size());
  } else {
    args.append<uint32_t>(outputs[0].data_size());
  }

  // Choose work per thread: the strided kernel only vectorizes when the last
  // dim splits evenly.
  if (!contiguous && shape.back() % work_per_thread != 0) {
    work_per_thread = 1;
  }

  // Launch kernel.
  const char* index_type = large ? "int64_t" : "uint32_t";
  std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
  if (contiguous) {
    kernel_name +=
        fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
  } else {
    kernel_name += fmt::format(
        "_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
  }

  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  for (const auto& out : outputs) {
    encoder.set_output_array(out);
  }

  auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
  auto [num_blocks, block_dims] =
      get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
  encoder.add_kernel_node_raw(
      kernel, num_blocks, block_dims, {}, 0, args.args());
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/conv/conv.h
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"

#include <algorithm>

namespace mlx::core {

// Convolution parameters copied by value into kernels. Spatial dims are read
// from axes [1, NDIM] of `in`, `wt` and `out`; `in_strides` holds all
// NDIM + 2 strides of `in`.
template <int NDIM>
struct ConvParams {
  int N; // Batch size
  int C; // In channels
  int O; // Out channels
  int strides[NDIM];
  int padding[NDIM];
  int kernel_dilation[NDIM];
  int input_dilation[NDIM];
  int groups;
  bool flip;
  int in_spatial_dims[NDIM];
  int wt_spatial_dims[NDIM];
  int out_spatial_dims[NDIM];
  int64_t in_strides[NDIM + 2];

  ConvParams(
      const array& in,
      const array& wt,
      const array& out,
      const std::vector<int>& strides,
      const std::vector<int>& padding,
      const std::vector<int>& kernel_dilation,
      const std::vector<int>& input_dilation,
      int groups,
      bool flip)
      : N(in.shape(0)),
        C(in.shape(-1)),
        O(wt.shape(0)),
        groups(groups),
        flip(flip) {
    std::copy_n(strides.begin(), NDIM, this->strides);
    std::copy_n(padding.begin(), NDIM, this->padding);
    std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
    std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
    std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
    std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
    std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
    std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
  }
};

void gemm_grouped_conv(
    cu::CommandEncoder& encoder,
    const array& in,
    const array& wt,
    array& out,
    const std::vector<int>& strides,
    const std::vector<int>& padding,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation,
    int groups,
    bool flip,
    Stream s);

void gemm_conv(
    cu::CommandEncoder& encoder,
    const array& in,
    const array& wt,
    array& out,
    const std::vector<int>& strides,
    const std::vector<int>& padding,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation,
    bool flip,
    Stream s);

// Makes `in` and `wt` row-contiguous (registering any copies as encoder
// temporaries so they outlive the enqueued work) and dispatches to the
// grouped or ungrouped gemm-based convolution.
inline void gemm_conv(
    cu::CommandEncoder& encoder,
    array in,
    array wt,
    array& out,
    const std::vector<int>& strides,
    const std::vector<int>& padding,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation,
    int groups,
    bool flip,
    Stream s) {
  if (!in.flags().row_contiguous) {
    in = contiguous_copy_gpu(in, s);
    encoder.add_temporary(in);
  }
  if (!wt.flags().row_contiguous) {
    wt = contiguous_copy_gpu(wt, s);
    encoder.add_temporary(wt);
  }

  if (groups == 1) {
    gemm_conv(
        encoder,
        in,
        wt,
        out,
        strides,
        padding,
        kernel_dilation,
        input_dilation,
        flip,
        s);
  } else {
    gemm_grouped_conv(
        encoder,
        in,
        wt,
        out,
        strides,
        padding,
        kernel_dilation,
        input_dilation,
        groups,
        flip,
        s);
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/conv/gemm_conv.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Writes the unfolded (im2col) matrix of shape
// (N * prod(out_spatial), prod(wt_spatial) * C). Launch layout:
//   x (blocks * threads) -> flattened filter position,
//   y                    -> input channel,
//   z                    -> unfolded rows, grid-strided because CUDA caps
//                           gridDim.z at 65535.
// Out-of-bounds / non-dilation-aligned taps are written as zero.
template <typename T, int NDIM>
__global__ void naive_unfold_nd(
    const T* in,
    T* out,
    int filter_size,
    int out_pixels,
    const __grid_constant__ ConvParams<NDIM> params) {
  auto block = cg::this_thread_block();
  auto tid = block.group_index();
  auto lid = block.thread_index();

  // [0, H_wt * W_wt)
  int index_wt_spatial_base = tid.x * block.dim_threads().x + lid.x;
  if (index_wt_spatial_base >= filter_size / params.C) {
    return;
  }

  in += tid.y; // [0, C)

  int64_t num_rows = static_cast<int64_t>(params.N) * out_pixels;
  for (int64_t row = tid.z; row < num_rows; row += gridDim.z) {
    int index_batch = row / out_pixels; // [0, N)
    int index_out_spatial = row % out_pixels; // [0, H_out * W_out)
    int index_wt_spatial = index_wt_spatial_base;

    // 64-bit offset: row * filter_size can exceed 32 bits.
    T* out_ptr = out + row * filter_size + index_wt_spatial * params.C + tid.y;

    bool valid = index_batch < params.N;

    // Get the coordinates in input.
    int index_in[NDIM] = {};
#pragma unroll
    for (int i = NDIM - 1; i >= 0; --i) {
      int index_out = index_out_spatial % params.out_spatial_dims[i];
      int index_wt = index_wt_spatial % params.wt_spatial_dims[i];

      if (params.flip) {
        index_wt = params.wt_spatial_dims[i] - index_wt - 1;
      }

      int index = index_out * params.strides[i] - params.padding[i] +
          index_wt * params.kernel_dilation[i];
      int index_max =
          1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);

      valid &= (index >= 0) && (index < index_max) &&
          (index % params.input_dilation[i] == 0);

      index_in[i] = index / params.input_dilation[i];

      index_out_spatial /= params.out_spatial_dims[i];
      index_wt_spatial /= params.wt_spatial_dims[i];
    }

    if (valid) {
      // in_strides are int64_t; keep the accumulated offset 64-bit.
      int64_t in_offset = index_batch * params.in_strides[0];
#pragma unroll
      for (int i = 0; i < NDIM; ++i) {
        in_offset += index_in[i] * params.in_strides[i + 1];
      }
      *out_ptr = in[in_offset];
    } else {
      *out_ptr = T{0};
    }
  }
}

} // namespace cu

// Allocates the (mat_M, mat_K) unfolded input as an encoder temporary and
// enqueues naive_unfold_nd to fill it.
template <int NDIM>
array unfold_inputs_nd(
    cu::CommandEncoder& encoder,
    const array& in,
    int mat_M,
    int mat_K,
    int mat_N,
    ConvParams<NDIM>& params) {
  array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
  encoder.add_temporary(unfolded);

  int filter_size = params.C;
#pragma unroll
  for (int i = 0; i < NDIM; ++i) {
    filter_size *= params.wt_spatial_dims[i];
  }

  int out_pixels = 1;
#pragma unroll
  for (int i = 0; i < NDIM; ++i) {
    out_pixels *= params.out_spatial_dims[i];
  }

  // CUDA limits gridDim.z to 65535; the kernel strides over the rest.
  constexpr int max_grid_dim_z = 65535;

  int wt_spatial_size = mat_K / params.C;
  dim3 block_dims;
  block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
  dim3 num_blocks;
  num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
  num_blocks.y = params.C;
  num_blocks.z = std::min(mat_M, max_grid_dim_z);

  encoder.set_input_array(in);
  encoder.set_output_array(unfolded);
  dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    encoder.add_kernel_node(
        cu::naive_unfold_nd<DataType, NDIM>,
        num_blocks,
        block_dims,
        gpu_ptr<DataType>(in),
        gpu_ptr<DataType>(unfolded),
        filter_size,
        out_pixels,
        params);
  });

  return unfolded;
}

// Convolution as a single gemm: unfolded input (M, K) times the weight viewed
// as (K, O) via a transposed, shared-buffer view (no copy).
template <int NDIM>
void gemm_conv_nd(
    cu::CommandEncoder& encoder,
    const array& in,
    const array& wt,
    array& out,
    ConvParams<NDIM>& params,
    Stream s) {
  // Get gemm shapes.
  int mat_M = out.size() / params.O; // N * H_out * W_out
  int mat_K = wt.size() / params.O; // C * H_wt * W_wt
  int mat_N = params.O; // O

  // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
  array in_unfolded =
      unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);

  // Reshape weight to (C * H_wt * W_wt, O) for gemm.
  array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
  wt_reshaped.copy_shared_buffer(
      wt,
      {1, mat_K},
      {false, false, /* col_contiguous */ true},
      wt.data_size());

  // Single batch.
  Shape batch_shape{1};
  Strides a_batch_strides{0};
  Strides b_batch_strides{0};

  // Run matmul.
  CublasGemm gemm(
      encoder.device(),
      in.dtype(),
      false, // a_transposed
      mat_M, // a_rows
      mat_K, // a_cols
      mat_K, // lda
      true, // b_transposed
      mat_K, // b_rows
      mat_N, // b_cols
      mat_K, // ldb
      batch_shape.back(),
      a_batch_strides.back(),
      b_batch_strides.back());
  gemm.run(
      encoder,
      out,
      in_unfolded,
      wt_reshaped,
      batch_shape,
      a_batch_strides,
      b_batch_strides);
}

// Ungrouped gemm-based convolution for 1-3 spatial dims.
// Throws std::runtime_error for any other rank.
void gemm_conv(
    cu::CommandEncoder& encoder,
    const array& in,
    const array& wt,
    array& out,
    const std::vector<int>& strides,
    const std::vector<int>& padding,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation,
    bool flip,
    Stream s) {
  int conv_ndim = in.ndim() - 2;
  if (conv_ndim < 1 || conv_ndim > 3) {
    throw std::runtime_error(
        fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
  }
  dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
    ConvParams<ndim_constant()> params(
        in,
        wt,
        out,
        strides,
        padding,
        kernel_dilation,
        input_dilation,
        1, // groups
        flip);
    gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/conv/gemm_grouped_conv.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Like naive_unfold_nd, but within each unfolded row the channel index is the
// outer coordinate and the filter position the inner one (the "transpose"),
// so that each group's slice [g * K, (g + 1) * K) is contiguous.
// Launch layout: x -> filter position, y -> input channel, z -> rows
// (grid-strided because CUDA caps gridDim.z at 65535).
template <typename T, int NDIM>
__global__ void naive_grouped_unfold_transpose_nd(
    const T* in,
    T* out,
    int filter_size,
    int out_pixels,
    const __grid_constant__ ConvParams<NDIM> params) {
  auto block = cg::this_thread_block();
  auto tid = block.group_index();
  auto lid = block.thread_index();

  // [0, H_wt * W_wt)
  int index_wt_spatial_base = tid.x * block.dim_threads().x + lid.x;
  if (index_wt_spatial_base >= filter_size / params.C) {
    return;
  }

  in += tid.y; // [0, C)

  int64_t num_rows = static_cast<int64_t>(params.N) * out_pixels;
  for (int64_t row = tid.z; row < num_rows; row += gridDim.z) {
    int index_batch = row / out_pixels; // [0, N)
    int index_out_spatial = row % out_pixels; // [0, H_out * W_out)
    int index_wt_spatial = index_wt_spatial_base;

    // 64-bit offset: row * filter_size can exceed 32 bits.
    T* out_ptr = out + row * filter_size + tid.y * (filter_size / params.C);

    bool valid = index_batch < params.N;

    // Get the coordinates in input.
    int index_in[NDIM] = {};
    int wt_stride = 1;
#pragma unroll
    for (int i = NDIM - 1; i >= 0; --i) {
      int index_out = index_out_spatial % params.out_spatial_dims[i];
      int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
      out_ptr += index_wt * wt_stride;

      if (params.flip) {
        index_wt = params.wt_spatial_dims[i] - index_wt - 1;
      }

      int index = index_out * params.strides[i] - params.padding[i] +
          index_wt * params.kernel_dilation[i];
      int index_max =
          1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);

      valid &= (index >= 0) && (index < index_max) &&
          (index % params.input_dilation[i] == 0);

      index_in[i] = index / params.input_dilation[i];

      index_out_spatial /= params.out_spatial_dims[i];
      index_wt_spatial /= params.wt_spatial_dims[i];
      wt_stride *= params.wt_spatial_dims[i];
    }

    if (valid) {
      // in_strides are int64_t; keep the accumulated offset 64-bit.
      int64_t in_offset = index_batch * params.in_strides[0];
#pragma unroll
      for (int i = 0; i < NDIM; ++i) {
        in_offset += index_in[i] * params.in_strides[i + 1];
      }
      *out_ptr = in[in_offset];
    } else {
      *out_ptr = T{0};
    }
  }
}

} // namespace cu

// Allocates the (mat_M, mat_K * groups) unfolded input as an encoder
// temporary and enqueues naive_grouped_unfold_transpose_nd to fill it.
template <int NDIM>
array grouped_unfold_transpose_inputs_nd(
    cu::CommandEncoder& encoder,
    const array& in,
    int mat_M,
    int mat_K,
    int mat_N,
    ConvParams<NDIM>& params) {
  array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
  encoder.add_temporary(unfolded);

  int filter_size = params.C;
#pragma unroll
  for (int i = 0; i < NDIM; ++i) {
    filter_size *= params.wt_spatial_dims[i];
  }

  int out_pixels = 1;
#pragma unroll
  for (int i = 0; i < NDIM; ++i) {
    out_pixels *= params.out_spatial_dims[i];
  }

  // CUDA limits gridDim.z to 65535; the kernel strides over the rest.
  constexpr int max_grid_dim_z = 65535;

  int wt_spatial_size = (mat_K * params.groups) / params.C;
  dim3 block_dims;
  block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
  dim3 num_blocks;
  num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
  num_blocks.y = params.C;
  num_blocks.z = std::min(mat_M, max_grid_dim_z);

  encoder.set_input_array(in);
  encoder.set_output_array(unfolded);
  dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    encoder.add_kernel_node(
        cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
        num_blocks,
        block_dims,
        gpu_ptr<DataType>(in),
        gpu_ptr<DataType>(unfolded),
        filter_size,
        out_pixels,
        params);
  });

  return unfolded;
}

// Grouped convolution as a batched gemm with one batch entry per group.
template <int NDIM>
void gemm_grouped_conv_nd(
    cu::CommandEncoder& encoder,
    const array& in,
    const array& wt,
    array& out,
    ConvParams<NDIM>& params,
    Stream s) {
  // Get gemm shapes.
  int C_per_group = params.C / params.groups;
  int O_per_group = params.O / params.groups;
  int mat_M = out.size() / params.O; // N * H_out * W_out
  int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
  int mat_N = O_per_group; // O_per_group

  // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
  array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(
      encoder, in, mat_M, mat_K, mat_N, params);

  // Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
  int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
  array wt_view(
      {params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
  wt_view.copy_shared_buffer(
      wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
  array wt_reshaped = contiguous_copy_gpu(wt_view, s);
  // The copy is only read by the gemm enqueued below; keep it alive until the
  // encoder's work completes (as the inline gemm_conv does for its copies).
  encoder.add_temporary(wt_reshaped);

  // Batch with size of groups.
  Shape batch_shape{params.groups};
  Strides a_batch_strides{mat_K};
  Strides b_batch_strides{mat_N * mat_K};

  // Run matmul.
  CublasGemm gemm(
      encoder.device(),
      in.dtype(),
      false, // a_transposed
      mat_M, // a_rows
      mat_K, // a_cols
      mat_K * params.groups, // lda
      true, // b_transposed
      mat_K, // b_rows
      mat_N, // b_cols
      mat_K, // ldb
      batch_shape.back(),
      a_batch_strides.back(),
      b_batch_strides.back());
  gemm.set_out(
      out.dtype(),
      false, // out_transposed
      mat_M, // out_rows
      mat_N, // out_cols
      mat_N * params.groups, // out_ld
      params.groups, // batch_count
      mat_N); // batch_stride
  gemm.run(
      encoder,
      out,
      in_unfolded,
      wt_reshaped,
      batch_shape,
      a_batch_strides,
      b_batch_strides);
}

// Grouped gemm-based convolution for 1-3 spatial dims.
// Throws std::runtime_error for any other rank.
void gemm_grouped_conv(
    cu::CommandEncoder& encoder,
    const array& in,
    const array& wt,
    array& out,
    const std::vector<int>& strides,
    const std::vector<int>& padding,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation,
    int groups,
    bool flip,
    Stream s) {
  int conv_ndim = in.ndim() - 2;
  if (conv_ndim < 1 || conv_ndim > 3) {
    throw std::runtime_error(
        fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
  }
  dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
    ConvParams<ndim_constant()> params(
        in,
        wt,
        out,
        strides,
        padding,
        kernel_dilation,
        input_dilation,
        groups,
        flip);
    gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/conv.cpp
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"

// NOTE(review): the two system header names below, and template argument
// lists further down, were stripped from this extract; confirm upstream.
#include
#include

namespace mlx::core {

namespace {

// Which cuDNN graph formulation a convolution is lowered to. CONV_FALLBACK
// records that cuDNN could not build a graph and the GEMM fallback is used.
enum ConvBackendType {
  CONV_FALLBACK,
  CONV_FORWARD,
  CONV_BACKWARD_INPUT,
  CONV_BACKWARD_WEIGHT,
};

// Plain-data key (used as BytesKey::pod in eval_gpu) describing a convolution:
// device, dtype, operand shapes, conv hyper-parameters and buffer alignments.
struct ConvCacheKey {
  int device_id;
  fe::DataType_t cudnn_dtype;
  std::array input_shape;
  std::array weight_shape;
  std::array stride;
  std::array padding_lo;
  std::array padding_hi;
  std::array dilation;
  int groups;
  bool flip;
  uint8_t input_alignment;
  uint8_t weight_alignment;
  uint8_t output_alignment;
};

// Process-wide LRU cache from ConvCacheKey to (chosen backend, built graph or
// nullopt). Default capacity is 128; the "MLX_CUDA_CONV_CACHE_SIZE" name is
// presumably an environment variable that overrides it.
auto& conv_cache() {
  static LRUBytesKeyCache<
      ConvCacheKey,
      std::pair>>
      cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
  return cache;
}

// Translate MLX conv parameters into the (stride, pre-padding, post-padding,
// dilation) tuple used to build the cuDNN graph for |backend_type|:
//  - CONV_BACKWARD_INPUT: input dilation becomes the cuDNN stride and both
//    paddings are recomputed for the dgrad formulation.
//  - CONV_BACKWARD_WEIGHT: kernel dilation and kernel strides swap roles and
//    the post-padding is set equal to the pre-padding.
//  - otherwise (forward): values are passed through.
// Spatial dims are read from shape(1 + i), i.e. a channels-last layout.
auto get_conv_settings(
    ConvBackendType backend_type,
    array& x,
    array& w,
    array& y,
    const std::vector& kernel_strides,
    const std::vector& padding_lo_,
    const std::vector& padding_hi_,
    const std::vector& kernel_dilation,
    const std::vector& input_dilation) {
  auto padding_lo = convert_vector(padding_lo_);
  auto padding_hi = convert_vector(padding_hi_);

  if (backend_type == CONV_BACKWARD_INPUT) {
    for (int i = 0; i < padding_lo.size(); ++i) {
      // Extent of the dilated kernel along spatial dim i.
      int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
      padding_lo[i] = wt_size - padding_lo[i] - 1;

      int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);
      int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);
      padding_hi[i] = out_size - in_size + padding_hi[i];
    }
    return std::make_tuple(
        convert_vector(input_dilation),
        std::move(padding_lo),
        std::move(padding_hi),
        convert_vector(kernel_dilation));

  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
    padding_hi = padding_lo;
    return std::make_tuple(
        convert_vector(kernel_dilation),
        std::move(padding_lo),
        std::move(padding_hi),
        convert_vector(kernel_strides));

  } else {
    return std::make_tuple(
        convert_vector(kernel_strides),
        std::move(padding_lo),
        std::move(padding_hi),
        convert_vector(kernel_dilation));
  }
}

// Build a cuDNN frontend graph for one convolution. float16/bfloat16 are
// computed in float32. Returns std::nullopt when prepare() reports the
// configuration as unsupported, so the caller can try another backend.
std::optional build_conv_graph(
    cu::CommandEncoder& encoder,
    ConvBackendType backend_type,
    Dtype dtype,
    array& x,
    array& w,
    array& y,
    const std::vector& stride,
    const std::vector& padding_lo,
    const std::vector& padding_hi,
    const std::vector& dilation) {
  auto compute_dtype =
      (dtype == float16 || dtype == bfloat16) ? float32 : dtype;
  DnnGraph graph(encoder.device().get_cudnn_handle(), dtype, compute_dtype);
  auto x_ = graph.tensor_nchw("X", 'x', x);
  auto w_ = graph.tensor_nchw("W", 'w', w);

  // Options shared by the fprop / dgrad / wgrad attribute builders.
  auto set_options = [&](auto& options) {
    options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))
        .set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
        .set_stride(stride)
        .set_pre_padding(padding_lo)
        .set_post_padding(padding_hi)
        .set_dilation(dilation);
  };

  // NOTE(review): y_ stays null if backend_type is none of the three cases
  // below (e.g. CONV_FALLBACK); eval_gpu only ever passes these three.
  std::shared_ptr y_;
  if (backend_type == CONV_FORWARD) {
    auto options = fe::graph::Conv_fprop_attributes();
    set_options(options);
    y_ = graph.conv_fprop(x_, w_, options);
  } else if (backend_type == CONV_BACKWARD_INPUT) {
    auto options = fe::graph::Conv_dgrad_attributes();
    set_options(options);
    y_ = graph.conv_dgrad(x_, w_, options);
  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
    auto options = fe::graph::Conv_wgrad_attributes();
    set_options(options);
    // wgrad receives the operands in (w_, x_) order.
    y_ = graph.conv_wgrad(w_, x_, options);
  }
  graph.tensor_nchw(y_, 'y', y)->set_output(true);

  if (graph.prepare().is_bad()) {
    return std::nullopt;
  }
  // Exclude engines that down-convert inputs, and tensor-core engines for
  // float32 when TF32 is disabled.
  graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
  if (dtype == float32 && !env::enable_tf32()) {
    graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});
  }
  CHECK_CUDNN_FE_ERROR(graph.build());
  return graph;
}

// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
array group_transpose(
    const array& x,
    int groups,
    int group_dim,
    int axis1,
    int axis2,
    Stream s) {
  // Splits |group_dim| into (groups, size / groups), swaps |axis1| and
  // |axis2| (indices into |x|, shifted past the inserted axis), then merges
  // dims group_dim and group_dim + 1 back together. With groups == 1 this is
  // a plain swapaxes.
  if (groups == 1) {
    return swapaxes_in_eval(x, axis1, axis2);
  }
  int ndim = x.ndim();
  // Normalize negative axes.
  if (group_dim < 0) {
    group_dim += ndim;
  }
  if (axis1 < 0) {
    axis1 += ndim;
  }
  if (axis2 < 0) {
    axis2 += ndim;
  }
  // A new |groups| axis is inserted at group_dim, shifting later axes by one.
  if (group_dim <= axis1) {
    axis1 += 1;
  }
  if (group_dim <= axis2) {
    axis2 += 1;
  }
  auto shape = x.shape();
  shape.insert(shape.begin() + group_dim, groups);
  shape[group_dim + 1] = shape[group_dim + 1] / groups;
  array x_trans = reshape_in_eval(x, std::move(shape), s);
  x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
  x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
  return x_trans;
}

// Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to be called multiple times in one
// eval_gpu, with cost of possible redundant copies.
std::tuple prepare_args(
    cu::CommandEncoder& encoder,
    ConvBackendType backend_type,
    array in,
    array wt,
    array out,
    int groups,
    Stream s) {
  // Transpose the args depending on the backend type.
  // TODO: Handle groups.
  if (backend_type == CONV_BACKWARD_INPUT) {
    wt = group_transpose(wt, groups, 0, 0, -1, s);
  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
    in = group_transpose(in, groups, -1, 0, -1, s);
    wt = swapaxes_in_eval(wt, 0, -1);
    // Create a contiguous array that shares the data with |out|, but with dim
    // C_in and C_out swapped.
    Shape shape(out.shape());
    std::swap(shape.front(), shape.back());
    // Row-major strides for the swapped shape.
    Strides strides(shape.size(), 1);
    for (int i = shape.size() - 2; i >= 0; --i) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }
    array intermediate(std::move(shape), out.dtype(), nullptr, {});
    intermediate.copy_shared_buffer(
        out, std::move(strides), {true, true, false}, out.data_size());
    out = intermediate;
  }

  // cuDNN requires contiguous input.
  if (!in.flags().row_contiguous) {
    in = contiguous_copy_gpu(in, s);
    encoder.add_temporary(in);
  }
  if (!wt.flags().row_contiguous) {
    wt = contiguous_copy_gpu(wt, s);
    encoder.add_temporary(wt);
  }

  return {std::move(in), std::move(wt), std::move(out)};
}

// Register inputs and outputs before actually running conv op. Can only be
// called once per eval_gpu.
void register_args(
    cu::CommandEncoder& encoder,
    ConvBackendType backend_type,
    array& in,
    array& wt,
    array& intermediate_out,
    array& final_out) {
  encoder.set_input_array(in);
  encoder.set_input_array(wt);
  encoder.set_output_array(final_out);

  if (backend_type == CONV_BACKWARD_WEIGHT) {
    // Turn |out| into a strided array, which will have C_in and C_out swapped
    // in vjp and the final |grad_weight| will then be contiguous.
    Strides strides = intermediate_out.strides();
    std::swap(strides.front(), strides.back());
    final_out.copy_shared_buffer(
        intermediate_out,
        std::move(strides),
        {false, false, false},
        intermediate_out.data_size());
  }
}

} // namespace

// CUDA convolution. Control flow:
//  1. Look up the configuration in conv_cache(). On a hit, either replay the
//     cached cuDNN graph or, if the cache recorded CONV_FALLBACK, run
//     gemm_conv directly.
//  2. On a miss, try candidate cuDNN backends in heuristic order and cache
//     the first graph that builds.
//  3. If none builds, run gemm_conv and cache CONV_FALLBACK.
void Convolution::eval_gpu(const std::vector& inputs, array& out_) {
  nvtx3::scoped_range r("Convolution::eval_gpu");
  if (out_.size() == 0) {
    return;
  }
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  assert(inputs.size() == 2);
  // Local handles; prepare_args may rebind them to transposed views/copies.
  array in = inputs[0];
  array wt = inputs[1];
  array out = out_;
  out.set_data(cu::malloc_async(out.nbytes(), encoder));
  Dtype dtype = out.dtype();

  // Search cache.
  // NOTE(review): input_dilation_ is not part of the key although it feeds
  // get_conv_settings() and gemm_conv(). Two convolutions differing only in
  // input dilation would share an entry unless another field differs —
  // confirm that the remaining fields always disambiguate.
  BytesKey cache_key;
  cache_key.pod.device_id = encoder.device().cuda_device();
  cache_key.pod.cudnn_dtype = dtype_to_cudnn_type(dtype);
  cache_key.pod.input_shape = vector_key(in.shape());
  cache_key.pod.weight_shape = vector_key(wt.shape());
  cache_key.pod.stride = vector_key(kernel_strides_);
  cache_key.pod.padding_lo = vector_key(padding_lo_);
  cache_key.pod.padding_hi = vector_key(padding_hi_);
  cache_key.pod.dilation = vector_key(kernel_dilation_);
  cache_key.pod.groups = groups_;
  cache_key.pod.flip = flip_;
  cache_key.pod.input_alignment = get_alignment(in);
  cache_key.pod.weight_alignment = get_alignment(wt);
  cache_key.pod.output_alignment = get_alignment(out);

  if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
    auto& [backend_type, graph] = it->second;
    if (graph) {
      // Run cached graph.
      std::tie(in, wt, out) =
          prepare_args(encoder, backend_type, in, wt, out, groups_, s);
      register_args(encoder, backend_type, in, wt, out, out_);
      CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
          encoder,
          {
              {'x', gpu_ptr(in)},
              {'w', gpu_ptr(wt)},
              {'y', gpu_ptr(out)},
          }));
    } else {
      // Run fallback kernel.
      gemm_conv(
          encoder,
          in,
          wt,
          out,
          kernel_strides_,
          padding_lo_,
          kernel_dilation_,
          input_dilation_,
          groups_,
          flip_,
          s);
    }
    return;
  }

  // There is no reliable way to deduce the proper cuDNN backend for the
  // convolution, so we make a best guess and then try.
  SmallVector try_backends;
  if (flip_) {
    // When weight is flipped, we assume it is backward input convolution.
    try_backends.push_back(CONV_BACKWARD_INPUT);
  } else {
    // Otherwise it could be backward weight convolution or forward convolution,
    // mathematically there is no difference so we have to use heuristics.
    // Empirically backward convolutions have large kernel dimensions, and
    // usually have |in| and |wt| transposed.
    // NOTE(review): with a channels-last layout, shape(2) is a spatial dim for
    // 2D/3D convolutions but the channel dim for 1D (ndim == 3) — confirm
    // this is the intended comparison.
    if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
        wt.shape(2) > out.shape(2)) {
      try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
    } else {
      try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
    }
  }

  // Try to build op graph.
  // backend_type is only read when |graph| holds a value, which implies it
  // was assigned in the loop.
  ConvBackendType backend_type;
  std::optional graph;
  for (auto try_backend : try_backends) {
    auto [x, w, y] =
        prepare_args(encoder, try_backend, in, wt, out, groups_, s);
    auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(
        try_backend,
        x,
        w,
        y,
        kernel_strides_,
        padding_lo_,
        padding_hi_,
        kernel_dilation_,
        input_dilation_);
    graph = build_conv_graph(
        encoder,
        try_backend,
        dtype,
        x,
        w,
        y,
        stride,
        padding_lo,
        padding_hi,
        dilation);
    if (graph) {
      // Adopt the operands prepared for the backend that succeeded.
      backend_type = try_backend;
      in = std::move(x);
      wt = std::move(w);
      out = std::move(y);
      break;
    }
  }

  if (graph) {
    register_args(encoder, backend_type, in, wt, out, out_);
    CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
        encoder,
        {
            {'x', gpu_ptr(in)},
            {'w', gpu_ptr(wt)},
            {'y', gpu_ptr(out)},
        }));
    conv_cache().emplace(
        cache_key, std::make_pair(backend_type, std::move(*graph)));
    return;
  }

  // Use fallback kernel for settings not supported by cuDNN.
  gemm_conv(
      encoder,
      in,
      wt,
      out,
      kernel_strides_,
      padding_lo_,
      kernel_dilation_,
      input_dilation_,
      groups_,
      flip_,
      s);
  // Record the failure so later calls with this key go straight to gemm_conv.
  conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt));
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/copy/copy.cuh
================================================
// Copyright © 2025 Apple Inc.
#pragma once #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" namespace mlx::core { void copy_contiguous( cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out, int64_t offset_in, int64_t offset_out); void copy_general( cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out, int64_t offset_in, int64_t offset_out, const Shape& shape, const Strides& strides_in, const Strides& strides_out); void copy_general_dynamic( cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out, int64_t offset_in, int64_t offset_out, const Shape& shape, const Strides& strides_in, const Strides& strides_out, const array& dynamic_offset_in, const array& dynamic_offset_out); void copy_general_input( cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out, int64_t offset_in, int64_t offset_out, const Shape& shape, const Strides& strides_in); } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/copy/copy_contiguous.cu ================================================ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/copy/copy.cuh"

// NOTE(review): the system header name below and template argument lists in
// this file were stripped from this extract; they are reproduced as found.
#include

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Broadcast fill: every output element becomes in[0] cast to Out. Each thread
// covers N_READS consecutive elements; a thread whose chunk would run past
// |size| falls back to an element-wise loop.
template
__global__ void copy_s(const In* in, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = cast_to(in[0]);
    }
  } else {
    AlignedVector out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = cast_to(in[0]);
    }
    store_vector(out, index, out_vec);
  }
}

// Element-wise cast-copy of a contiguous range, N_READS elements per thread,
// with the same tail handling as copy_s.
template
__global__ void copy_v(const In* in, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = cast_to(in[i]);
    }
  } else {
    auto in_vec = load_vector(in, index);
    AlignedVector out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = cast_to(in_vec[i]);
    }
    store_vector(out, index, out_vec);
  }
}

} // namespace cu

// Launch copy_s (CopyType::Scalar) or copy_v (CopyType::Vector) over
// out.data_size() elements, starting |in_offset| / |out_offset| elements
// into the respective buffers.
void copy_contiguous(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t in_offset,
    int64_t out_offset) {
  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      // The kernels evaluate (index + 1) * N_READS and index * N_READS in
      // IdxT. For the 32-bit index type (presumed unsigned; its spelling is
      // not visible in this extract) and a size just below UINT32_MAX, that
      // product wraps for the last threads. Those threads then take the
      // vectorized path and touch elements past |size|. Keeping the 32-bit
      // path at or below INT32_MAX — the same threshold copy_general and
      // copy_general_input use — leaves enough headroom that it cannot wrap.
      dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
        using InType = cuda_type_t;
        using OutType = cuda_type_t;
        using IdxT = std::conditional_t;
        constexpr int N_READS = 16 / sizeof(InType);
        auto kernel = cu::copy_s;
        if (ctype == CopyType::Vector) {
          kernel = cu::copy_v;
        }
        auto [num_blocks, block_dims] = get_launch_args(
            out.data_size(), out.shape(), out.strides(), large(), N_READS);
        encoder.add_kernel_node(
            kernel,
            num_blocks,
            block_dims,
            gpu_ptr(in) + in_offset,
            gpu_ptr(out) + out_offset,
            out.data_size());
      });
    });
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/copy/copy_general.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/copy/copy.cuh"

// NOTE(review): in this extract the system header name below and template
// argument lists (template<...>, CastOp<...>, load_vector<...>,
// cuda_type_t<...>, ...) were stripped; they are reproduced as found.
#include

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Strided-to-strided cast-copy with the rank fixed at compile time (NDIM).
// 2D launch: y selects a "row" (a position over every dim but the last) and
// x selects a chunk of N_READS elements along the last dim.
template
__global__ void copy_gg_nd(
    const In* in,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ cuda::std::array shape,
    const __grid_constant__ cuda::std::array strides_in,
    const __grid_constant__ cuda::std::array strides_out) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[NDIM - 1];
  auto in_stride_x = strides_in[NDIM - 1];
  auto out_stride_x = strides_out[NDIM - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  // Start of this row in the input and in the output.
  auto [idx_in, idx_out] = elem_to_loc_nd(
      index_rest * shape_x,
      shape.data(),
      strides_in.data(),
      strides_out.data());

  auto in_vec =
      load_vector(in + idx_in, index_x, shape_x, in_stride_x, In(0));
  AlignedVector out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = CastOp{}(in_vec[i]);
  }
  store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
}

// Same as copy_gg_nd but with the rank passed at run time (used for ndim >= 4).
template
__global__ void copy_gg(
    const In* in,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    const __grid_constant__ Strides strides_out,
    int ndim) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[ndim - 1];
  auto in_stride_x = strides_in[ndim - 1];
  auto out_stride_x = strides_out[ndim - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [idx_in, idx_out] = elem_to_loc(
      index_rest * shape_x,
      shape.data(),
      strides_in.data(),
      strides_out.data(),
      ndim);

  auto in_vec =
      load_vector(in + idx_in, index_x, shape_x, in_stride_x, In(0));
  AlignedVector out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = CastOp{}(in_vec[i]);
  }
  store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
}

} // namespace cu

// Host launcher for strided-to-strided copies. Uses 4 elements per thread
// along the last dim when that dim has at least 4 elements (1 otherwise);
// ranks up to 3 use the compile-time-rank kernel, larger ranks copy_gg.
void copy_general(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out) {
  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      dispatch_bool(
          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
          [&](auto large) {
            using InType = cuda_type_t;
            using OutType = cuda_type_t;
            using IdxT = std::conditional_t;
            const InType* in_ptr = gpu_ptr(in) + offset_in;
            OutType* out_ptr = gpu_ptr(out) + offset_out;
            int ndim = shape.size();
            // Number of elements described by |shape|.
            size_t data_size = 1;
            for (auto& s : shape)
              data_size *= s;
            int work_per_thread = 1;
            auto dim0 = ndim > 0 ? shape.back() : 1;
            auto rest = data_size / dim0;
            if (dim0 >= 4) {
              work_per_thread = 4;
            }
            dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
            auto block_dims = get_block_dims(dim0, rest, 1);
            uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
            uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
            if (ndim <= 3) {
              dispatch_1_2_3(ndim, [&](auto ndim_constant) {
                auto kernel = cu::copy_gg_nd;
                if (work_per_thread == 4) {
                  kernel = cu::copy_gg_nd;
                }
                encoder.add_kernel_node(
                    kernel,
                    {num_blocks_x, num_blocks_y},
                    block_dims,
                    in_ptr,
                    out_ptr,
                    rest,
                    const_param(shape),
                    const_param(strides_in),
                    const_param(strides_out));
              });
            } else { // ndim >= 4
              auto kernel = cu::copy_gg;
              if (work_per_thread == 4) {
                kernel = cu::copy_gg;
              }
              encoder.add_kernel_node(
                  kernel,
                  {num_blocks_x, num_blocks_y},
                  block_dims,
                  in_ptr,
                  out_ptr,
                  rest,
                  const_param(shape),
                  const_param(strides_in),
                  const_param(strides_out),
                  ndim);
            }
          });
    });
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/copy/copy_general_dynamic.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// One element per thread, rank fixed at compile time. The extra offsets are
// dereferenced from device memory (*offset_in / *offset_out) at run time.
template
__global__ void copy_gg_dynamic_nd(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ cuda::std::array shape,
    const __grid_constant__ cuda::std::array strides_in,
    const __grid_constant__ cuda::std::array strides_out,
    const int64_t* offset_in,
    const int64_t* offset_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_nd(
        index, shape.data(), strides_in.data(), strides_out.data());
    out[idx_out + *offset_out] = CastOp{}(in[idx_in + *offset_in]);
  }
}

// Same as copy_gg_dynamic_nd with the rank passed at run time.
template
__global__ void copy_gg_dynamic(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    const __grid_constant__ Strides strides_out,
    int ndim,
    const int64_t* offset_in,
    const int64_t* offset_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc(
        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
    out[idx_out + *offset_out] = CastOp{}(in[idx_in + *offset_in]);
  }
}

} // namespace cu

void copy_general_dynamic(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out,
    const array& dynamic_offset_in,
    const array& dynamic_offset_out) {
  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      dispatch_bool(
          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
          [&](auto large) {
            using InType = cuda_type_t;
            using OutType = cuda_type_t;
            using IdxT = std::conditional_t;
            const InType* in_ptr = gpu_ptr(in) + offset_in;
            OutType* out_ptr = gpu_ptr(out) + offset_out;
            int ndim = shape.size();
            // NOTE(review): the launch size and the kernels' |size| come from
            // out.size(), while the kernels map indices through |shape|. If
            // |shape| covers only part of |out| (e.g. an update region), the
            // surplus threads rely on elem_to_loc's index mapping to stay in
            // range — confirm.
            if (ndim <= 3) {
              dispatch_1_2_3(ndim, [&](auto dims_constant) {
                auto [num_blocks, block_dims] = get_launch_args(out, large());
                encoder.add_kernel_node(
                    cu::copy_gg_dynamic_nd<
                        InType,
                        OutType,
                        IdxT,
                        dims_constant()>,
                    num_blocks,
                    block_dims,
                    in_ptr,
                    out_ptr,
                    out.size(),
                    const_param(shape),
                    const_param(strides_in),
                    const_param(strides_out),
                    gpu_ptr(dynamic_offset_in),
                    gpu_ptr(dynamic_offset_out));
              });
            } else { // ndim >= 4
              auto [num_blocks, block_dims] = get_launch_args(out, large());
              encoder.add_kernel_node(
                  cu::copy_gg_dynamic,
                  num_blocks,
                  block_dims,
                  in_ptr,
                  out_ptr,
                  out.size(),
                  const_param(shape),
                  const_param(strides_in),
                  const_param(strides_out),
                  ndim,
                  gpu_ptr(dynamic_offset_in),
                  gpu_ptr(dynamic_offset_out));
            }
          });
    });
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/copy/copy_general_input.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include

namespace mlx::core {

// Side of the square thread block used by copy_col_row.
static constexpr int TILE_SIZE = 16;

namespace cu {

namespace cg = cooperative_groups;

// Strided input -> row-contiguous output, rank fixed at compile time. Same 2D
// launch shape as copy_gg_nd; output row |index_rest| starts at
// shape_x * index_rest.
template
__global__ void copy_g_nd(
    const In* in,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ cuda::std::array shape,
    const __grid_constant__ cuda::std::array strides) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[NDIM - 1];
  auto stride_x = strides[NDIM - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto idx =
      elem_to_loc_nd(index_rest * shape_x, shape.data(), strides.data());
  auto in_vec =
      load_vector(in + idx, index_x, shape_x, stride_x, In(0));
  AlignedVector out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = CastOp{}(in_vec[i]);
  }
  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}

// Same as copy_g_nd with the rank passed at run time (used for ndim >= 4).
template
__global__ void copy_g(
    const In* in,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides,
    int ndim) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[ndim - 1];
  auto stride_x = strides[ndim - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto idx =
      elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
  auto in_vec =
      load_vector(in + idx, index_x, shape_x, stride_x, In(0));
  AlignedVector out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = CastOp{}(in_vec[i]);
  }
  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}

// Copy a column-major (rows x cols) matrix into row-major order through a
// shared-memory tile of N_READS * TILE_SIZE elements per side. Each row is
// padded by 4 / sizeof(Out) elements, presumably to reduce bank conflicts.
// Phase 1 reads columns of the input into the tile; phase 2 writes rows of
// the tile to the output.
template
__global__ void
copy_col_row(const In* in, Out* out, int64_t rows, int64_t cols) {
  __shared__ Out
      tile[N_READS * TILE_SIZE][N_READS * TILE_SIZE + 4 / sizeof(Out)];

  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();

  auto tile_row = grid.block_index().x * TILE_SIZE * N_READS;
  auto tile_col = grid.block_index().y * TILE_SIZE * N_READS;

  auto tidx = block.thread_index().x;
  auto tidy = N_READS * block.thread_index().y;

  auto in_ptr = in + (tile_col + tidy) * rows + tile_row;

#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    if ((tile_col + tidy + i) < cols) {
      auto in_vec = load_vector(in_ptr, tidx, rows - tile_row, In(0));
#pragma unroll
      for (int j = 0; j < N_READS; ++j) {
        tile[N_READS * tidx + j][tidy + i] = CastOp{}(in_vec[j]);
      }
      in_ptr += rows;
    }
  }

  block.sync();

  auto out_ptr = out + (tile_row + tidy) * cols + tile_col;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    if ((tile_row + tidy + i) < rows) {
      AlignedVector out_vec;
#pragma unroll
      for (int j = 0; j < N_READS; ++j) {
        out_vec[j] = tile[tidy + i][N_READS * tidx + j];
      }
      store_vector(out_ptr, tidx, out_vec, cols - tile_col);
      out_ptr += cols;
    }
  }
}

} // namespace cu

// Host launcher for strided-input / contiguous-output copies. A 2D
// column-contiguous input takes the tiled copy_col_row path; otherwise the
// last dim is processed 8, 4 or 1 elements per thread depending on its
// length.
void copy_general_input(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in) {
  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      using InType = cuda_type_t;
      using OutType = cuda_type_t;
      const InType* in_ptr = gpu_ptr(in) + offset_in;
      OutType* out_ptr = gpu_ptr(out) + offset_out;
      int ndim = shape.size();

      // Column contiguous to row contiguous specialization
      if (ndim == 2 && strides_in[0] == 1 && strides_in[1] == shape[0]) {
        constexpr int work_per_thread =
            std::min(static_cast(16 / sizeof(OutType)), 8);
        dim3 block_dims = {TILE_SIZE, TILE_SIZE};
        uint32_t num_blocks_x =
            cuda::ceil_div(shape[0], TILE_SIZE * work_per_thread);
        uint32_t num_blocks_y =
            cuda::ceil_div(shape[1], TILE_SIZE * work_per_thread);
        auto kernel = cu::copy_col_row;
        encoder.add_kernel_node(
            kernel,
            {num_blocks_x, num_blocks_y},
            block_dims,
            in_ptr,
            out_ptr,
            int64_t(shape[0]),
            int64_t(shape[1]));
        return;
      }

      dispatch_bool(
          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
          [&](auto large) {
            using IdxT = std::conditional_t;
            int work_per_thread = 8;
            auto dim0 = ndim > 0 ? shape.back() : 1;
            // NOTE(review): |rest| is derived from out.size() here, whereas
            // copy_general derives it from |shape|; this assumes
            // out.size() == prod(shape).
            auto rest = out.size() / dim0;
            if (dim0 >= 4 && dim0 < 8) {
              work_per_thread = 4;
            } else if (dim0 < 4) {
              work_per_thread = 1;
            }
            dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
            auto block_dims = get_block_dims(dim0, rest, 1);
            uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
            uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
            if (ndim <= 3) {
              dispatch_1_2_3(ndim, [&](auto dims_constant) {
                auto kernel = cu::copy_g_nd;
                if (work_per_thread == 8) {
                  kernel = cu::copy_g_nd;
                } else if (work_per_thread == 4) {
                  kernel = cu::copy_g_nd;
                }
                encoder.add_kernel_node(
                    kernel,
                    {num_blocks_x, num_blocks_y},
                    block_dims,
                    in_ptr,
                    out_ptr,
                    rest,
                    const_param(shape),
                    const_param(strides_in));
              });
            } else { // ndim >= 4
              auto kernel = cu::copy_g;
              if (work_per_thread == 8) {
                kernel = cu::copy_g;
              } else if (work_per_thread == 4) {
                kernel = cu::copy_g;
              }
              encoder.add_kernel_node(
                  kernel,
                  {num_blocks_x, num_blocks_y},
                  block_dims,
                  in_ptr,
                  out_ptr,
                  rest,
                  const_param(shape),
                  const_param(strides_in),
                  ndim);
            }
          });
    });
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/copy.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/copy/copy.cuh"

namespace mlx::core {

// Give |out| storage via set_copy_output_data (which reports whether |in|'s
// buffer was donated) and copy |in| into it. A GeneralGeneral request is
// downgraded to General before copying.
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
  auto& encoder = cu::get_command_encoder(s);
  bool donated = set_copy_output_data(
      in, out, ctype, [&](auto n) { return cu::malloc_async(n, encoder); });
  if (donated && in.dtype() == out.dtype()) {
    // If the output has the same type as the input then there is nothing to
    // copy, just use the buffer.
    return;
  }
  if (ctype == CopyType::GeneralGeneral) {
    ctype = CopyType::General;
  }
  copy_gpu_inplace(in, out, ctype, s);
}

// Copy |in| into the already-allocated |out| with explicit shape, strides and
// offsets. Scalar / Vector copies go to copy_contiguous. General and
// GeneralGeneral first collapse contiguous dims (INT32_MAX is passed as the
// size cap). A GeneralGeneral copy with at least one dynamic offset uses
// copy_general_dynamic, substituting a zero int64 scalar for the missing one.
void copy_gpu_inplace(
    const array& in,
    array& out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out,
    int64_t offset_in,
    int64_t offset_out,
    CopyType ctype,
    const Stream& s,
    std::optional dynamic_offset_in,
    std::optional dynamic_offset_out) {
  if (out.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
    copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
    return;
  }

  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
    auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
        shape, std::vector{strides_in, strides_out}, INT32_MAX);
    if (ctype == CopyType::General) {
      copy_general_input(
          encoder,
          ctype,
          in,
          out,
          offset_in,
          offset_out,
          shape_collapsed,
          strides_vec[0]);
    } else {
      if (dynamic_offset_in || dynamic_offset_out) {
        if (!dynamic_offset_in) {
          dynamic_offset_in = array(0, int64);
          encoder.add_temporary(*dynamic_offset_in);
        }
        if (!dynamic_offset_out) {
          dynamic_offset_out = array(0, int64);
          encoder.add_temporary(*dynamic_offset_out);
        }
        encoder.set_input_array(*dynamic_offset_in);
        encoder.set_input_array(*dynamic_offset_out);
        copy_general_dynamic(
            encoder,
            ctype,
            in,
            out,
            offset_in,
            offset_out,
            shape_collapsed,
            strides_vec[0],
            strides_vec[1],
            *dynamic_offset_in,
            *dynamic_offset_out);
      } else {
        copy_general(
            encoder,
            ctype,
            in,
            out,
            offset_in,
            offset_out,
            shape_collapsed,
            strides_vec[0],
            strides_vec[1]);
      }
    }
    return;
  }
}

// Allocate |out| and fill every element with in[0] (a Scalar copy).
void fill_gpu(const array& in, array& out, const Stream& s) {
  if (out.size() == 0) {
    return;
  }
  auto& encoder = cu::get_command_encoder(s);
  out.set_data(cu::malloc_async(out.nbytes(), encoder));
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}

// Reshape |in| into |out|: share the buffer with new strides when
// prepare_reshape allows it, otherwise materialize a row-contiguous copy.
void reshape_gpu(const array& in, array& out, Stream s) {
  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
  if (copy_necessary) {
    auto& encoder = cu::get_command_encoder(s);
    out.set_data(cu::malloc_async(out.nbytes(), encoder));
    copy_gpu_inplace(
        in,
        out,
        in.shape(),
        in.strides(),
        make_contiguous_strides(in.shape()),
        0,
        0,
        CopyType::General,
        s);
  } else {
    shared_buffer_reshape(in, out_strides, out);
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/cublas_utils.cpp
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/cublas_utils.h"
#include "mlx/backend/cuda/cuda.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace cublas_utils {

namespace {

// Owns one cuBLASLt matmul preference whose max workspace size is chosen from
// the device's compute capability.
struct CublasPreference {
  CublasPreference(cu::Device& device) {
    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
    // for Hopper+:
    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
    uint64_t MiB = 1024 * 1024;
    uint64_t workspace_size =
        device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;

    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
        pref_,
        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
        &workspace_size,
        sizeof(uint64_t)));
  }

  // NOTE(review): CHECK_CUBLAS_ERROR throws on failure (see cuda_utils.h); an
  // exception leaving this implicitly-noexcept destructor calls
  // std::terminate.
  ~CublasPreference() {
    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
  }

  cublasLtMatmulPreference_t pref_{nullptr};
};

} // namespace

// Return the process-wide matmul preference.
// NOTE(review): the function-local static is built from the first |device|
// passed in, so later calls for other devices reuse that device's workspace
// size — confirm this is intended for mixed-architecture systems.
cublasLtMatmulPreference_t get_preference(cu::Device& device) {
  static CublasPreference pref(device);
  return pref.pref_;
}

// Create a cuBLASLt layout for a rows x cols matrix with leading dimension
// |ld| (rows and cols are swapped when |transposed|). Batch count and strided
// batch offset are only set when batch_count > 1.
cublasLtMatrixLayout_t create_matrix_layout(
    cudaDataType_t type,
    uint64_t rows,
    uint64_t cols,
    bool transposed,
    int64_t ld,
    int32_t batch_count,
    int64_t batch_stride) {
  cublasLtMatrixLayout_t desc;
  if (transposed) {
    std::swap(rows, cols);
  }
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
  if (batch_count > 1) {
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
        &batch_count,
        sizeof(int32_t)));
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
        &batch_stride,
        sizeof(int64_t)));
  }
  return desc;
}

} // namespace cublas_utils

// Release the a/b/c/out layouts and the matmul descriptor.
CublasMatmulBase::~CublasMatmulBase() {
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}

// Create the matmul descriptor and the a/b/out layouts (c_desc_ is not
// created here) and reset the cached heuristic so the next execute_matmul
// queries a fresh algorithm.
void CublasMatmulBase::init_base(
    cu::Device& device,
    cudaDataType_t scale_type,
    cublasComputeType_t compute_type,
    cudaDataType_t data_type,
    cudaDataType_t output_type,
    bool a_transposed,
    uint64_t a_rows,
    uint64_t a_cols,
    int64_t lda,
    bool b_transposed,
    uint64_t b_rows,
    uint64_t b_cols,
    int64_t ldb,
    int32_t batch_count,
    int64_t a_batch_stride,
    int64_t b_batch_stride) {
  M_ = a_rows;
  N_ = b_cols;
  scale_type_ = scale_type;
  handle_ = device.get_cublaslt_handle();
  pref_ = cublas_utils::get_preference(device);
  heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;

  CHECK_CUBLAS_ERROR(
      cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type));

  // In cublasLt matrices use column-major layout, while it is possible to use
  // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
  // epilogue does not work with the option. So instead we swap A and B to make
  // cublasLt return the row-major result, which works because:
  // - the data of a matrix in row-major layout is identical to its transpose in
  //   column-major layout
  // - C^T = (A @ B)^T = B^T @ A^T
  cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSA,
      &a_op,
      sizeof(cublasOperation_t)));
  cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSB,
      &b_op,
      sizeof(cublasOperation_t)));

  a_desc_ = cublas_utils::create_matrix_layout(
      data_type,
      b_cols,
      b_rows,
      b_transposed,
      ldb,
      batch_count,
      b_batch_stride);
  b_desc_ = cublas_utils::create_matrix_layout(
      data_type,
      a_cols,
      a_rows,
      a_transposed,
      lda,
      batch_count,
      a_batch_stride);
  // Default output layout: dense (a_rows x b_cols) per batch.
  out_desc_ = cublas_utils::create_matrix_layout(
      output_type, b_cols, a_rows, false, b_cols, batch_count, b_cols * a_rows);
}

// Enqueue the matmul on the encoder's stream. |c| (or |out| when |c| is null)
// is passed as the C operand. The algorithm is chosen by
// cublasLtMatmulAlgoGetHeuristic and cached in heuristic_ for later calls.
// NOTE(review): the cached heuristic is computed with c_desc_ or out_desc_
// depending on whether the first call passed |c|, and it is not refreshed if
// set_bias() changes the epilogue afterwards — confirm callers configure
// everything before the first execute.
void CublasMatmulBase::execute_matmul(
    cu::CommandEncoder& encoder,
    void* out,
    const void* a,
    const void* b,
    const void* c,
    const void* alpha_ptr,
    const void* beta_ptr) {
  if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
    int ret = 0;
    CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
        handle_,
        matmul_desc_,
        a_desc_,
        b_desc_,
        c ? c_desc_ : out_desc_,
        out_desc_,
        pref_,
        1,
        &heuristic_,
        &ret));
    if (ret == 0) {
      throw std::runtime_error("Can not find algorithm for matmul.");
    }
  }

  void* workspace_ptr = allocate_workspace(encoder, heuristic_.workspaceSize);

  // Execute matmul
  auto capture = encoder.capture_context();
  CHECK_CUBLAS_ERROR(cublasLtMatmul(
      handle_,
      matmul_desc_,
      alpha_ptr,
      b, // a and b are swapped for row-major layout
      a_desc_,
      a,
      b_desc_,
      beta_ptr,
      c ? c : out,
      c ? c_desc_ : out_desc_,
      out,
      out_desc_,
      &heuristic_.algo,
      workspace_ptr,
      heuristic_.workspaceSize,
      encoder.stream()));
}

// Switch the epilogue to CUBLASLT_EPILOGUE_BIAS using |bias|'s device pointer
// and register |bias| as an input of the encoder.
void CublasMatmulBase::set_bias(
    cu::CommandEncoder& encoder,
    const array& bias) {
  encoder.set_input_array(bias);
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_EPILOGUE,
      &epilogue,
      sizeof(epilogue)));
  auto* bias_ptr = gpu_ptr(bias);
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
      &bias_ptr,
      sizeof(bias_ptr)));
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/cublas_utils.h
================================================
// Copyright © 2025 Apple Inc.
#pragma once #include #include "mlx/array.h" #include "mlx/backend/cuda/device.h" #include "mlx/dtype_utils.h" namespace mlx::core { namespace cublas_utils { // Get the shared cublas preference for a device cublasLtMatmulPreference_t get_preference(cu::Device& device); cublasLtMatrixLayout_t create_matrix_layout( cudaDataType_t type, uint64_t rows, uint64_t cols, bool transposed, int64_t ld, int32_t batch_count, int64_t batch_stride); inline cudaDataType_t dtype_to_cublas_type(Dtype dtype, std::string_view tag) { switch (dtype) { case float16: return CUDA_R_16F; case bfloat16: return CUDA_R_16BF; case float32: return CUDA_R_32F; case float64: return CUDA_R_64F; case complex64: return CUDA_C_32F; default: throw std::runtime_error( fmt::format( "Unsupported dtype in {}: {}.", tag, dtype_to_string(dtype))); } } } // namespace cublas_utils class CublasMatmulBase { public: virtual ~CublasMatmulBase(); void set_bias(cu::CommandEncoder& encoder, const array& bias); protected: CublasMatmulBase() = default; // Common member variables shared by all matmul types uint64_t M_; uint64_t N_; cudaDataType_t scale_type_; cublasLtMatmulPreference_t pref_{nullptr}; cublasLtHandle_t handle_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatrixLayout_t a_desc_{nullptr}; cublasLtMatrixLayout_t b_desc_{nullptr}; cublasLtMatrixLayout_t c_desc_{nullptr}; cublasLtMatrixLayout_t out_desc_{nullptr}; cublasLtMatmulHeuristicResult_t heuristic_; void init_base( cu::Device& device, cudaDataType_t scale_type, cublasComputeType_t compute_type, cudaDataType_t data_type, cudaDataType_t output_type, bool a_transposed, uint64_t a_rows, uint64_t a_cols, int64_t lda, bool b_transposed, uint64_t b_rows, uint64_t b_cols, int64_t ldb, int32_t batch_count, int64_t a_batch_stride, int64_t b_batch_stride); void execute_matmul( cu::CommandEncoder& encoder, void* out, const void* a, const void* b, const void* c, const void* alpha_ptr, const void* beta_ptr); }; } // namespace mlx::core 
================================================ FILE: mlx/backend/cuda/cuda.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include #include #include #include "mlx/api.h" namespace mlx::core::cu { /* Check if the CUDA backend is available. */ MLX_API bool is_available(); /* Get information about a CUDA device. */ MLX_API const std::unordered_map>& device_info(int device_index = 0); } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/cuda_utils.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include #include #include #include namespace mlx::core { // Throw exception if the cuda API does not succeed. void check_cublas_error(const char* name, cublasStatus_t err); void check_cuda_error(const char* name, cudaError_t err); void check_cuda_error(const char* name, CUresult err); void check_cudnn_error(const char* name, cudnnStatus_t err); // The macro version that prints the command that failed. #define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd)) #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) #define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd)) // Base class for RAII managed CUDA resources. 
template class CudaHandle { public: CudaHandle(Handle handle = nullptr) : handle_(handle) {} CudaHandle(CudaHandle&& other) : handle_(other.handle_) { assert(this != &other); other.handle_ = nullptr; } ~CudaHandle() { // Skip if there was an error to avoid throwing in the destructors if (cudaPeekAtLastError() != cudaSuccess) { return; } reset(); } CudaHandle(const CudaHandle&) = delete; CudaHandle& operator=(const CudaHandle&) = delete; CudaHandle& operator=(CudaHandle&& other) { assert(this != &other); reset(); std::swap(handle_, other.handle_); return *this; } void reset() { if (handle_ != nullptr) { CHECK_CUDA_ERROR(Destroy(handle_)); handle_ = nullptr; } } operator Handle() const { return handle_; } protected: Handle handle_; }; namespace cu { class Device; }; // namespace cu // Wrappers of CUDA resources. class CudaGraph : public CudaHandle { public: using CudaHandle::CudaHandle; explicit CudaGraph(cu::Device& device); void end_capture(cudaStream_t stream); }; class CudaGraphExec : public CudaHandle { public: void instantiate(cudaGraph_t graph); }; class CudaStream : public CudaHandle { public: using CudaHandle::CudaHandle; explicit CudaStream(cu::Device& device); }; } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/cudnn_utils.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/cudnn_utils.h" #include "mlx/backend/cuda/device.h" namespace mlx::core { namespace { #define RETURN_IF_ERROR(cmd) \ if (auto ret = cmd; ret.is_bad()) { \ return ret; \ } // In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN // whether a tensor is contiguous is determined with: // shape[dim] == shape[dim + 1] * strides[dim + 1] // So a contiguous array with singleton dims in MLX may be mistakenly treated // as strided in cuDNN, and we work around it by normalizing the strides. 
std::vector normalized_strides(const array& x) { std::vector strides(x.strides().begin(), x.strides().end()); if (std::all_of( strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) { strides.back() = 1; return strides; } if (!x.flags().row_contiguous || x.ndim() < 2) { return strides; } for (int i = x.ndim() - 2; i >= 0; --i) { if (x.shape(i) == 1) { strides[i] = x.shape(i + 1) * strides[i + 1]; } } return strides; } // Return the shape and strides after transposing from NHWC to NCHW. inline auto nhwc_to_nchw(const array& x) { auto shape = convert_vector(x.shape()); auto strides = normalized_strides(x); assert(shape.size() >= 3); shape.insert(shape.begin() + 1, shape.back()); shape.erase(shape.end() - 1); strides.insert(strides.begin() + 1, strides.back()); strides.erase(strides.end() - 1); return std::make_tuple(std::move(shape), std::move(strides)); } } // namespace fe::error_t DnnGraph::prepare() { RETURN_IF_ERROR(validate()); try { RETURN_IF_ERROR(build_operation_graph(handle_)); } catch (cudnn_frontend::cudnnException& error) { // cuDNN bug: they did not catch all exceptions in the API. return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()}; } RETURN_IF_ERROR(create_execution_plans({fe::HeurMode_t::A})); return {}; } fe::error_t DnnGraph::build() { RETURN_IF_ERROR(check_support(handle_)); RETURN_IF_ERROR(build_plans(handle_)); return {}; } fe::error_t DnnGraph::encode_graph( cu::CommandEncoder& encoder, std::unordered_map variant_pack) { cudnnSetStream(handle_, encoder.stream()); auto* workspace_ptr = prepare_workspace(encoder); if (!cached_cuda_graph_) { // First call: populate the CUDA graph from the cuDNN execution plan. // Also compute and cache the subgraph key to avoid calling // cudaGraphKernelNodeGetAttribute on every subsequent call (expensive // on WDDM where each driver API call has ~40-400us overhead). 
cached_cuda_graph_.emplace(encoder.device()); RETURN_IF_ERROR(populate_cuda_graph( handle_, variant_pack, workspace_ptr, *cached_cuda_graph_)); std::tie(cached_subgraph_key_, cached_is_updatable_) = cu::subgraph_to_key(*cached_cuda_graph_); } else { // Subsequent calls: patch data pointers without re-running kernel setup. RETURN_IF_ERROR(update_cuda_graph( handle_, variant_pack, workspace_ptr, *cached_cuda_graph_)); } // Add the cuDNN child graph to the parent CUDA graph for batched launch. // The pre-computed subgraph key avoids expensive per-node attribute queries. encoder.add_graph_node( *cached_cuda_graph_, cached_subgraph_key_, cached_is_updatable_); return {}; } fe::error_t DnnGraph::encode_capturing( cu::CommandEncoder& encoder, std::unordered_map variant_pack) { auto* workspace_ptr = prepare_workspace(encoder); auto capture = encoder.capture_context(); cudnnSetStream(handle_, encoder.stream()); auto ret = execute(handle_, variant_pack, workspace_ptr); if (ret.is_bad()) { capture.discard = true; } return ret; } void* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) { int64_t workspace_size = 0; CHECK_CUDNN_FE_ERROR(get_workspace_size(workspace_size)); return allocate_workspace(encoder, workspace_size); } void DnnGraph::set_tensor_attrs( std::shared_ptr& tensor, int64_t uid, const array& x, const std::vector& shape, const std::vector& strides) { tensor->set_uid(uid) .set_alignment(get_alignment(x)) .set_data_type(dtype_to_cudnn_type(x.dtype())) .set_dim(shape) .set_stride(strides); } void DnnGraph::set_tensor_attrs( std::shared_ptr& tensor, int64_t uid, const array& x) { set_tensor_attrs( tensor, uid, x, convert_vector(x.shape()), normalized_strides(x)); } void DnnGraph::set_tensor_attrs_nchw( std::shared_ptr& tensor, int64_t uid, const array& x) { auto [shape, strides] = nhwc_to_nchw(x); set_tensor_attrs(tensor, uid, x, shape, strides); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/cudnn_utils.h 
================================================ // Copyright © 2025 Apple Inc. #pragma once #include #include #include "mlx/backend/cuda/cuda_utils.h" #include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/utils.h" #include "mlx/dtype_utils.h" #include #include namespace mlx::core { namespace cu { class CommandEncoder; } namespace fe = cudnn_frontend; #define CHECK_CUDNN_FE_ERROR(cmd) \ do { \ auto error = cmd; \ if (!error.is_good()) { \ throw std::runtime_error( \ fmt::format("{} failed: {}.", #cmd, error.get_message())); \ } \ } while (0) // Return pointer alignment of |x|'s data. inline uint8_t get_alignment(const array& x) { uint8_t alignment = 1; uintptr_t address = reinterpret_cast(gpu_ptr(x)); for (; alignment < 32; alignment *= 2) { if (address % (alignment * 2)) { return alignment; } } return alignment; } // Convert the type of elements in |vec| to |T|. template inline std::vector convert_vector(const Vec& vec) { return std::vector(vec.begin(), vec.end()); } // Map dtype to cudnn data type. inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) { switch (dtype) { case int8: return fe::DataType_t::INT8; case int32: return fe::DataType_t::INT32; case uint8: return fe::DataType_t::UINT8; case float16: return fe::DataType_t::HALF; case bfloat16: return fe::DataType_t::BFLOAT16; case float32: return fe::DataType_t::FLOAT; case float64: return fe::DataType_t::DOUBLE; default: throw std::runtime_error( fmt::format( "Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype))); } } // Return an array that can be used as map key for |vec| with size <= MAX_NDIM. // // There are 2 differences from the const_param util from kernel_utils.cuh: // 1. The rest of array is filled with 0. // 2. This util can be used in .cpp files. 
template inline std::array vector_key(const Vec& vec) { if (vec.size() > NDIM) { throw std::runtime_error( fmt::format("ndim can not be larger than {}.", NDIM)); } std::array result = {}; std::copy_n(vec.begin(), vec.size(), result.begin()); return result; } // Extends cuDNN graph with helpers. class DnnGraph : public fe::graph::Graph { public: DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32) : handle_(handle) { set_io_data_type(dtype_to_cudnn_type(io_dtype)); set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype)); set_compute_data_type(dtype_to_cudnn_type(compute_dtype)); } // Create a cuDNN tensor description from MLX array |x|. auto& tensor( std::shared_ptr& attrs, int64_t uid, const array& x) { set_tensor_attrs(attrs, uid, x); return attrs; } auto tensor(const char* name, int64_t uid, const array& x) { auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name)); tensor(attrs, uid, x); return attrs; } // Create a cuDNN tensor description from MLX array |x|, and transpose it from // NHWC layout to NCHW. auto& tensor_nchw( std::shared_ptr& attrs, int64_t uid, const array& x) { set_tensor_attrs_nchw(attrs, uid, x); return attrs; } auto tensor_nchw(const char* name, int64_t uid, const array& x) { auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name)); tensor_nchw(attrs, uid, x); return attrs; } // Create a 4D cuDNN tensor from 1D array, with |axis| being contiguous dim. auto tensor_4d(const char* name, int64_t uid, const array& x, int axis) { assert(x.ndim() == 1); auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name)); std::vector shape(4, 1); std::vector strides(4, 1); shape.at(axis) = x.size(); if (axis > 0) { strides.at(axis - 1) = x.size(); } set_tensor_attrs(attrs, uid, x, shape, strides); return attrs; } // Create a cuDNN tensor for scalar. 
auto scalar(const char* name, int64_t uid, Dtype dtype) { return Graph::tensor( fe::graph::Tensor_attributes() .set_name(name) .set_uid(uid) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(dtype_to_cudnn_type(dtype))); } // Call this before setting notes. fe::error_t prepare(); // Call this after setting notes. fe::error_t build(); // Add cuDNN graph to CUDA graph, using native CUDA graph API. fe::error_t encode_graph( cu::CommandEncoder& encoder, std::unordered_map variant_pack); // Add cuDNN graph to CUDA graph, using stream capture. fe::error_t encode_capturing( cu::CommandEncoder& encoder, std::unordered_map variant_pack); private: void* prepare_workspace(cu::CommandEncoder& encoder); void set_tensor_attrs( std::shared_ptr& tensor, int64_t uid, const array& x, const std::vector& shape, const std::vector& strides); void set_tensor_attrs( std::shared_ptr& tensor, int64_t uid, const array& x); void set_tensor_attrs_nchw( std::shared_ptr& tensor, int64_t uid, const array& x); cudnnHandle_t handle_; std::optional cached_cuda_graph_; std::string cached_subgraph_key_; bool cached_is_updatable_{true}; }; } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/custom_kernel.cpp ================================================ // Copyright © 2025 Apple Inc. 
#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/fast.h" #include "mlx/fast_primitives.h" #include #include namespace mlx::core::fast { namespace { constexpr const char* default_header = R"( #include "mlx/backend/cuda/device/utils.cuh" #include #define inf cuda::std::numeric_limits::infinity() )"; std::string template_arguments_hash( const std::vector>& template_args) { if (template_args.empty()) { return ""; } std::string hash; hash.reserve(512); for (const auto& [name, arg] : template_args) { if (std::holds_alternative(arg)) { hash += fmt::format("_{}", std::get(arg)); } else if (std::holds_alternative(arg)) { hash += (std::get(arg)) ? "_t" : "_f"; } else if (std::holds_alternative(arg)) { hash += "_"; hash += get_type_string(std::get(arg)); } } return hash; } std::string build_kernel( const std::string& func_name, const std::string& header, const std::string& source, const std::vector& input_names, const std::vector& inputs, const std::vector& output_names, const std::vector& output_dtypes, const std::vector>& template_args, const std::vector>& shape_infos) { std::string kernel_source; kernel_source.reserve(header.size() + source.size() + 8192); kernel_source += default_header; kernel_source += header; kernel_source += "namespace mlx::core::cu {\n\n" "namespace cg = cooperative_groups;\n\n"; kernel_source += "__global__ void "; kernel_source += func_name; kernel_source += "(\n"; // Add inputs for (int i = 0; i < inputs.size(); ++i) { const auto& name = input_names[i]; const auto& arr = inputs[i]; kernel_source += " const "; kernel_source += dtype_to_cuda_type(arr.dtype()); kernel_source += "* "; kernel_source += name; kernel_source += ",\n"; // Add input shape, strides and ndim if present in the source if (arr.ndim() > 0) { if (std::get<0>(shape_infos[i])) { kernel_source += " const __grid_constant__ Shape "; kernel_source += name; 
kernel_source += "_shape,\n"; } if (std::get<1>(shape_infos[i])) { kernel_source += " const __grid_constant__ Strides "; kernel_source += name; kernel_source += "_strides,\n"; } if (std::get<2>(shape_infos[i])) { kernel_source += " const __grid_constant__ int "; kernel_source += name; kernel_source += "_ndim,\n"; } } } // Add outputs for (int i = 0; i < output_names.size(); ++i) { const auto& name = output_names[i]; const auto& dtype = output_dtypes[i]; kernel_source += " "; kernel_source += dtype_to_cuda_type(dtype); kernel_source += "* "; kernel_source += name; if (i < output_names.size() - 1) { kernel_source += ",\n"; } else { kernel_source += ") {\n"; } } // Set compile time constants if (!template_args.empty()) { for (const auto& [name, arg] : template_args) { if (std::holds_alternative(arg)) { kernel_source += fmt::format(" constexpr int {} = {};\n", name, std::get(arg)); } else if (std::holds_alternative(arg)) { kernel_source += fmt::format( " constexpr bool {} = {};\n", name, std::get(arg)); } else { kernel_source += fmt::format( " using {} = {};\n", name, dtype_to_cuda_type(std::get(arg))); } } kernel_source += "\n"; } kernel_source += source; kernel_source += "\n}\n\n} // namespace mlx::core::cu\n"; return kernel_source; } } // namespace CustomKernelFunction cuda_kernel( const std::string& name, const std::vector& input_names, const std::vector& output_names, const std::string& source, const std::string& header, bool ensure_row_contiguous, int shared_memory) { if (output_names.empty()) { throw std::invalid_argument( "[custom_kernel] Must specify at least one output."); } std::vector> shape_infos; for (auto& n : input_names) { std::tuple shape_info; std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos; std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos; std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos; shape_infos.push_back(shape_info); } return [=, shape_infos = std::move(shape_infos)]( 
const std::vector& inputs, const std::vector& output_shapes, const std::vector& output_dtypes, std::tuple grid, std::tuple threadgroup, const std::vector>& template_args = {}, std::optional init_value = std::nullopt, bool verbose = false, StreamOrDevice s_ = {}) { if (inputs.size() != input_names.size()) { std::ostringstream msg; msg << "[custom_kernel] Expected `inputs` to have size " << input_names.size() << " but got size " << inputs.size() << "." << std::endl; throw std::invalid_argument(msg.str()); } if (output_shapes.size() != output_names.size()) { std::ostringstream msg; msg << "[custom_kernel] Expected `output_shapes` to have size " << output_names.size() << " but got size " << output_shapes.size() << "." << std::endl; throw std::invalid_argument(msg.str()); } if (output_dtypes.size() != output_names.size()) { std::ostringstream msg; msg << "[custom_kernel] Expected `output_dtypes` to have size " << output_names.size() << " but got size " << output_dtypes.size() << "." << std::endl; throw std::invalid_argument(msg.str()); } auto s = to_stream(s_); if (s.device != Device::gpu) { throw std::invalid_argument("[custom_kernel] Only supports the GPU."); } std::string kernel_name = "custom_kernel_" + name + template_arguments_hash(template_args); std::string kernel_source = build_kernel( kernel_name, header, source, input_names, inputs, output_names, output_dtypes, template_args, shape_infos); if (verbose) { std::cout << "Generated source code for `" << kernel_name << "`:" << std::endl << "```" << std::endl << kernel_source << std::endl << "```" << std::endl; } return array::make_arrays( std::move(output_shapes), std::move(output_dtypes), std::make_shared( s, std::move(kernel_name), std::move(kernel_source), grid, threadgroup, shape_infos, ensure_row_contiguous, init_value, std::vector{}, false, shared_memory), std::move(inputs)); }; } std::vector precompiled_cuda_kernel( const std::string& name, const std::string& compiled_source, const std::vector& inputs, 
const std::vector& output_shapes, const std::vector& output_dtypes, const std::vector& scalars, std::tuple grid, std::tuple threadgroup, int shared_memory, std::optional init_value, bool ensure_row_contiguous, StreamOrDevice s) { std::vector> shape_infos( inputs.size(), {false, false, false}); return array::make_arrays( output_shapes, output_dtypes, std::make_shared( to_stream(s), name, compiled_source, grid, threadgroup, shape_infos, ensure_row_contiguous, init_value, scalars, true, shared_memory), inputs); } void CustomKernel::eval_gpu( const std::vector& inputs, std::vector& outputs) { nvtx3::scoped_range r("CustomKernel::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); std::vector copies; // Allocate and initialize the output arrays for (auto& out : outputs) { if (init_value_) { copies.emplace_back(init_value_.value(), out.dtype()); fill_gpu(copies.back(), out, s); } else { out.set_data(cu::malloc_async(out.nbytes(), encoder)); } } // Create the input arrays and copy if needed auto check_input = [&copies, &s, this](const array& x) -> const array { bool no_copy = x.flags().row_contiguous; if (!ensure_row_contiguous_ || no_copy) { return x; } else { copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); copy_gpu(x, copies.back(), CopyType::General, s); return copies.back(); } }; std::vector checked_inputs; for (const array& in : inputs) { checked_inputs.push_back(check_input(in)); } // Compile the custom kernel std::string kernel_name = (is_precompiled_) ? 
name_ : "mlx::core::cu::" + name_; cu::JitModule& mod = cu::get_jit_module( s.device, name_, [&]() { return std::make_tuple( is_precompiled_, source_, std::vector{kernel_name}); }, false); // Make the arguments cu::KernelArgs args; for (int i = 0; i < checked_inputs.size(); i++) { const array& in = checked_inputs[i]; auto& shape_info = shape_infos_[i]; args.append(in); if (std::get<0>(shape_info)) { args.append_ndim(in.shape()); } if (std::get<1>(shape_info)) { args.append_ndim(in.strides()); } if (std::get<2>(shape_info)) { args.append(in.ndim()); } } for (auto& out : outputs) { args.append(out); } for (auto& s : scalar_arguments_) { if (std::holds_alternative(s)) { args.append(std::get(s)); } else if (std::holds_alternative(s)) { args.append(std::get(s)); } else if (std::holds_alternative(s)) { args.append(std::get(s)); } } // Make the grid const auto [tx, ty, tz] = threadgroup_; const auto [gx, gy, gz] = grid_; dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz); // Call the kernel for (const auto& in : checked_inputs) { encoder.set_input_array(in); } for (const auto& out : outputs) { encoder.set_output_array(out); } for (const auto& t : copies) { encoder.add_temporary(t); } auto kernel = mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) { if (smem > 0 && smem > 48000) { cuFuncSetAttribute( kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem); } }); encoder.add_kernel_node_raw( kernel, grid, block, {}, shared_memory_, args.args()); } } // namespace mlx::core::fast ================================================ FILE: mlx/backend/cuda/cutlass_utils.cuh ================================================ // Copyright © 2025 Apple Inc. #pragma once #include "mlx/dtype.h" #include #include #include namespace mlx::core { // Throw exception if the cutlass API does not succeed. 
inline void check_cutlass_error(const char* name, cutlass::Status status) { if (status != cutlass::Status::kSuccess) { throw std::runtime_error( fmt::format( "{} failed with code: {}.", name, cutlass::cutlassGetStatusString(status))); } } // The macro version that prints the command that failed. #define CHECK_CUTLASS_ERROR(cmd) ::mlx::core::check_cutlass_error(#cmd, (cmd)) // Maps CPU types to CUTLASS types. template struct CTypeToCutlassType { using type = T; }; template <> struct CTypeToCutlassType { using type = cutlass::half_t; }; template <> struct CTypeToCutlassType { using type = cutlass::bfloat16_t; }; template using cutlass_type_t = typename CTypeToCutlassType::type; } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/delayload.cpp ================================================ // Copyright © 2026 Apple Inc. #include "mlx/backend/common/utils.h" // clang-format off #include // must be included first #include // clang-format on namespace mlx::core { namespace fs = std::filesystem; inline fs::path relative_to_current_binary(const char* relative) { return fs::absolute(current_binary_dir() / relative); } inline fs::path cublas_bin_dir() { #if defined(MLX_CUDA_BIN_DIR) return MLX_CUDA_BIN_DIR; #else return relative_to_current_binary("../nvidia/cublas/bin"); #endif } fs::path load_nvrtc() { #if defined(MLX_CUDA_BIN_DIR) fs::path nvrtc_bin_dir = MLX_CUDA_BIN_DIR; #else fs::path nvrtc_bin_dir = relative_to_current_binary("../nvidia/cuda_nvrtc/bin"); #endif // Internally nvrtc loads some libs dynamically, add to search dirs. ::AddDllDirectory(nvrtc_bin_dir.c_str()); return nvrtc_bin_dir; } fs::path load_cudnn() { #if defined(MLX_CUDNN_BIN_DIR) fs::path cudnn_bin_dir = MLX_CUDNN_BIN_DIR; #else fs::path cudnn_bin_dir = relative_to_current_binary("../nvidia/cudnn/bin"); #endif // Must load cudnn_graph64_9.dll before locating symbols, otherwise We would // get errors like "Invalid handle. Cannot load symbol cudnnCreate". 
for (const auto& dll : fs::directory_iterator(cudnn_bin_dir)) { if (dll.path().filename().string().starts_with("cudnn_graph") && dll.path().extension() == ".dll") { ::LoadLibraryW(dll.path().c_str()); break; } } // Internally cuDNN loads some libs dynamically, add to search dirs. load_nvrtc(); ::AddDllDirectory(cudnn_bin_dir.c_str()); ::AddDllDirectory(cublas_bin_dir().c_str()); return cudnn_bin_dir; } // Called by system when failed to locate a lazy-loaded DLL. FARPROC WINAPI delayload_helper(unsigned dliNotify, PDelayLoadInfo pdli) { HMODULE mod = NULL; if (dliNotify == dliNotePreLoadLibrary) { std::string dll = pdli->szDll; if (dll.starts_with("cudnn")) { static auto cudnn_bin_dir = load_cudnn(); mod = ::LoadLibraryW((cudnn_bin_dir / dll).c_str()); } else if (dll.starts_with("cublas")) { mod = ::LoadLibraryW((cublas_bin_dir() / dll).c_str()); } else if (dll.starts_with("nvrtc")) { static auto nvrtc_bin_dir = load_nvrtc(); mod = ::LoadLibraryW((nvrtc_bin_dir / dll).c_str()); } } return reinterpret_cast(mod); } } // namespace mlx::core extern "C" const PfnDliHook __pfnDliNotifyHook2 = mlx::core::delayload_helper; ================================================ FILE: mlx/backend/cuda/device/atomic_ops.cuh ================================================ // Copyright © 2025 Apple Inc. 
#pragma once #include "mlx/backend/cuda/device/complex.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh" #include namespace mlx::core::cu { template inline __device__ void atomic_add(T* out, T val) { cuda::atomic_ref ref(*out); ref += val; } template inline __device__ void atomic_prod(T* out, T val) { cuda::atomic_ref ref(*out); T old = ref.load(); while (!ref.compare_exchange_strong(old, old * val)) { } } template inline __device__ void atomic_max(T* out, T val) { cuda::atomic_ref ref(*out); ref.fetch_max(val); } template inline __device__ void atomic_min(T* out, T val) { cuda::atomic_ref ref(*out); ref.fetch_min(val); } // Somehow cuda::atomic_ref does not provide atomic add for following types. template inline __device__ void atomic_add_general(T* out, T val) { cuda::atomic_ref ref(*out); T old = ref.load(); while (!ref.compare_exchange_strong(old, old + val)) { } } inline __device__ void atomic_add(__half* out, __half val) { atomicAdd(out, val); } inline __device__ void atomic_add(complex64_t* out, complex64_t val) { atomic_add_general(out, val); } inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) { #if __CUDA_ARCH__ < 800 atomic_add_general(out, val); #else atomicAdd(out, val); #endif } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device/binary_ops.cuh ================================================ // Copyright © 2025 Apple Inc. 
#pragma once

#include "mlx/backend/cuda/device/unary_ops.cuh"

#include <cuda/std/array>

namespace mlx::core::cu {

struct Add {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x + y;
  }
};

// Integers use C++ integer division; floating point values are divided and
// then truncated toward zero.
struct FloorDivide {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return x / y;
    } else {
      return cuda::std::trunc(x / y);
    }
  }
};

struct Divide {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x / y;
  }
};

// For signed integers and floating point the nonzero result is shifted by y
// when its sign differs from y's, so it takes the divisor's sign.
struct Remainder {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      if constexpr (cuda::std::is_signed_v<T>) {
        auto r = x % y;
        if (r != 0 && (r < 0 != y < 0)) {
          r += y;
        }
        return r;
      } else {
        return x % y;
      }
    } else if constexpr (is_complex_v<T>) {
      return x % y;
    } else {
      T r = cuda::std::fmod(x, y);
      if (r != 0 && (r < 0 != y < 0)) {
        r = r + y;
      }
      return r;
    }
  }
};

struct Equal {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x == y;
  }
};

// Equality that also treats NaN as equal to NaN (per component for complex).
struct NaNEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    using cuda::std::isnan;
    if constexpr (is_complex_v<T>) {
      return x == y ||
          (isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) &&
           isnan(y.imag())) ||
          (x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) ||
          (isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag());
    } else {
      return x == y || (isnan(x) && isnan(y));
    }
  }
};

struct Greater {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x > y;
  }
};

struct GreaterEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x >= y;
  }
};

struct Less {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x < y;
  }
};

struct LessEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x <= y;
  }
};

// log(exp(x) + exp(y)), computed as hi + log1p(exp(lo - hi)).
struct LogAddExp {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (is_complex_v<T>) {
      if (cuda::std::isnan(x.real()) || cuda::std::isnan(x.imag()) ||
          cuda::std::isnan(y.real()) || cuda::std::isnan(y.imag())) {
        return {
            cuda::std::numeric_limits<float>::quiet_NaN(),
            cuda::std::numeric_limits<float>::quiet_NaN()};
      }
      // Split into (operand with the larger real part, the other operand) so
      // both x and y always participate, including when the real parts tie.
      const bool x_is_hi = x.real() > y.real();
      auto hi = x_is_hi ? x : y;
      auto lo = x_is_hi ? y : x;
      auto lo_real = lo.real();
      auto hi_real = hi.real();
      if (!cuda::std::isfinite(lo_real) && (lo_real == hi_real)) {
        if (lo_real < 0) {
          return lo;
        } else {
          return Log{}(Exp{}(lo) + Exp{}(hi));
        }
      } else {
        return Log1p{}(Exp{}(lo - hi)) + hi;
      }
    } else {
      if (cuda::std::isnan(x) || cuda::std::isnan(y)) {
        return cuda::std::numeric_limits<T>::quiet_NaN();
      }
      T maxval = max(x, y);
      T minval = min(x, y);
      return (minval == -cuda::std::numeric_limits<T>::infinity() ||
              maxval == cuda::std::numeric_limits<T>::infinity())
          ? maxval
          : T(maxval + cuda::std::log1p(cuda::std::exp(minval - maxval)));
    }
  };
};

// NaN in x propagates (for complex: NaN in either component of x).
struct Maximum {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return max(x, y);
    } else if constexpr (is_complex_v<T>) {
      if (cuda::std::isnan(x.real()) || cuda::std::isnan(x.imag())) {
        return x;
      }
      return x > y ? x : y;
    } else {
      if (cuda::std::isnan(x)) {
        return x;
      }
      return x > y ? x : y;
    }
  }
};

// NaN in x propagates (for complex: NaN in either component of x).
struct Minimum {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return min(x, y);
    } else if constexpr (is_complex_v<T>) {
      if (cuda::std::isnan(x.real()) || cuda::std::isnan(x.imag())) {
        return x;
      }
      return x < y ? x : y;
    } else {
      if (cuda::std::isnan(x)) {
        return x;
      }
      return x < y ? x : y;
    }
  }
};

struct Multiply {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x * y;
  }
};

struct NotEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    if constexpr (is_complex_v<T>) {
      return x.real() != y.real() || x.imag() != y.imag();
    } else {
      return x != y;
    }
  }
};

// Integers use exponentiation by squaring (a negative exponent for a signed
// type yields 0); other types defer to cuda::std::pow.
struct Power {
  template <typename T>
  __device__ T operator()(T base, T exp) {
    if constexpr (cuda::std::is_integral_v<T>) {
      T res = 1;
      // Raising an integer to a negative power is undefined
      if constexpr (cuda::std::is_signed_v<T>) {
        if (exp < 0) {
          return 0;
        }
      }
      while (exp) {
        if (exp & 1) {
          res *= base;
        }
        exp >>= 1;
        base *= base;
      }
      return res;
    } else if constexpr (is_complex_v<T>) {
      return cuda::std::pow(base, exp);
    } else {
      return cuda::std::pow(base, exp);
    }
  }
};

struct Subtract {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x - y;
  }
};

struct LogicalAnd {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x && y;
  }
};

struct LogicalOr {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x || y;
  }
};

struct BitwiseAnd {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x & y;
  }
};

struct BitwiseOr {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x | y;
  }
};

struct BitwiseXor {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x ^ y;
  }
};

struct LeftShift {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x << y;
  }
};

struct RightShift {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x >> y;
  }
};

struct ArcTan2 {
  template <typename T>
  __device__ T operator()(T y, T x) {
    return cuda::std::atan2(y, x);
  }
};

// Returns {FloorDivide(x, y), Remainder(x, y)}.
struct DivMod {
  template <typename T>
  __device__ cuda::std::array<T, 2> operator()(T x, T y) {
    return {FloorDivide{}(x, y), Remainder{}(x, y)};
  }
};

} // namespace mlx::core::cu

================================================
FILE: mlx/backend/cuda/device/cast_op.cuh
================================================
// Copyright © 2025 Apple Inc.
#pragma once #include "mlx/backend/cuda/device/complex.cuh" #include #include namespace mlx::core::cu { // An op that does static_cast, with custom conversions for some types. template struct CastOp { static constexpr bool is_castable = cuda::std::is_convertible_v; __device__ DstT operator()(SrcT x) { return static_cast(x); } }; // Castings between complex and boolean. template struct CastOp, bool> { static constexpr bool is_castable = true; __device__ bool operator()(complex_t x) { return x.real() != 0 && x.imag() != 0; } }; template struct CastOp> { static constexpr bool is_castable = true; __device__ complex_t operator()(bool x) { return x ? complex_t{1, 1} : complex_t{0, 0}; } }; // Converting a complex number to real number discards the imaginary part. template struct CastOp, DstT, cuda::std::enable_if_t>> { static constexpr bool is_castable = cuda::std::is_convertible_v; __device__ DstT operator()(complex_t x) { static_assert(!is_complex_v); return static_cast(x.real()); } }; // Allow converting a real number to complex number. template struct CastOp, cuda::std::enable_if_t>> { static constexpr bool is_castable = cuda::std::is_convertible_v; __device__ complex_t operator()(SrcT x) { static_assert(!is_complex_v); return complex_t{static_cast(x), 0}; } }; // Do nothing when no casting is needed. template struct CastOp< SrcT, DstT, cuda::std::enable_if_t>> { static constexpr bool is_castable = true; __device__ SrcT operator()(SrcT x) { return x; } }; // In CUDA 11 the half types do not define conversions between some types, // provide fallbacks here. 
#if CUDART_VERSION < 12000
// Fallbacks compiled only when CUDART_VERSION < 12000. They are selected
// (via enable_if) when SrcT is not directly convertible to DstT and SrcT is
// not complex; they convert through an intermediate static_cast and then
// construct DstT from the result.
template
struct CastOp<
    SrcT,
    DstT,
    cuda::std::enable_if_t<
        !cuda::std::is_convertible_v && !is_complex_v &&
        (cuda::std::is_same_v || cuda::std::is_same_v)>> {
  static constexpr bool is_castable = true;

  __device__ DstT operator()(SrcT x) {
    return DstT(static_cast(x));
  }
};

// Second fallback: same shape as above, with two additional is_same_v
// exclusions in its enable_if condition.
template
struct CastOp<
    SrcT,
    DstT,
    cuda::std::enable_if_t<
        !cuda::std::is_convertible_v && !is_complex_v &&
        !cuda::std::is_same_v && !cuda::std::is_same_v &&
        (cuda::std::is_same_v || cuda::std::is_same_v)>> {
  static constexpr bool is_castable = true;

  __device__ DstT operator()(SrcT x) {
    return DstT(static_cast(x));
  }
};
#endif // CUDART_VERSION < 12000

// Helper to deduce the SrcT: the caller names only the destination type and
// SrcT is deduced from the argument; the work is delegated to CastOp.
template
inline __host__ __device__ auto cast_to(SrcT x) {
  return CastOp{}(x);
}

} // namespace mlx::core::cu
================================================
FILE: mlx/backend/cuda/device/complex.cuh
================================================
// Copyright © 2025 Apple Inc.

#pragma once

// Make multiplication and division faster.
#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS

#include
#include

namespace mlx::core::cu {

// TODO: Consider using a faster implementation as cuda::std::complex has to
// conform to C++ standard.
template
using complex_t = cuda::std::complex;
using complex64_t = complex_t;
using complex128_t = complex_t;

// Trait reporting whether a type is complex: false for the primary template,
// true for the partial specialization below. is_complex_v is the variable
// template shorthand.
template
struct is_complex : cuda::std::false_type {};

template
struct is_complex> : cuda::std::true_type {};

template
inline constexpr bool is_complex_v = is_complex::value;

// cuda::std::complex is missing some operators.
// Floored modulo applied independently to the real and imaginary parts:
// each part is computed as a - floor(a / b) * b.
template
inline __host__ __device__ complex_t operator%(
    complex_t a,
    complex_t b) {
  T r = a.real() - floor(a.real() / b.real()) * b.real();
  T i = a.imag() - floor(a.imag() / b.imag()) * b.imag();
  return complex_t{r, i};
}

// Lexicographic ordering: real parts are compared first, imaginary parts
// break ties.
template
inline __host__ __device__ bool operator>(complex_t a, complex_t b) {
  return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
}

template
inline __host__ __device__ bool operator<(complex_t a, complex_t b) {
  return operator>(b, a);
}

// NOTE(review): <= and >= are the negations of > and <, so when a compared
// part is NaN (where > and < are both false) they return true.
template
inline __host__ __device__ bool operator<=(complex_t a, complex_t b) {
  return !(a > b);
}

template
inline __host__ __device__ bool operator>=(complex_t a, complex_t b) {
  return !(a < b);
}

} // namespace mlx::core::cu
================================================
FILE: mlx/backend/cuda/device/config.h
================================================
// Copyright © 2025 Apple Inc.

// This file is used by both CUDA kernel code and host-only C++ code.

#pragma once

// The maximum dimensions of shape/strides passed as kernel parameters.
#define MAX_NDIM 10

// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
// warpSize variable exists, using it would prevent compile-time optimizations.
#define WARP_SIZE 32
================================================
FILE: mlx/backend/cuda/device/fp16_math.cuh
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include
#include
#include

namespace mlx::core::cu {

///////////////////////////////////////////////////////////////////////////////
// Binary ops for half types.
/////////////////////////////////////////////////////////////////////////////// #define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \ template \ __forceinline__ __device__ auto NAME(T x, T y) { \ if constexpr (cuda::std::is_same_v) { \ return HALF_OP(x, y); \ } else if constexpr (cuda::std::is_same_v) { \ return HALF_OP(x, y); \ } else { \ return ::NAME(x, y); \ } \ } MLX_DEFINE_BINARY_OP(max, __hmax) MLX_DEFINE_BINARY_OP(min, __hmin) #undef MLX_DEFINE_BINARY_OP /////////////////////////////////////////////////////////////////////////////// // Additional C++ operator overrides between half types and native types. /////////////////////////////////////////////////////////////////////////////// template constexpr bool is_integral_except = cuda::std::is_integral_v && !cuda::std::is_same_v; template constexpr bool is_arithmetic_except = cuda::std::is_arithmetic_v && !cuda::std::is_same_v; #define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \ template < \ typename T, \ typename = cuda::std::enable_if_t>> \ __forceinline__ __device__ HALF operator OP(HALF x, T y) { \ return FLOAT2HALF(HALF2FLOAT(x) OP static_cast(y)); \ } \ template < \ typename T, \ typename = cuda::std::enable_if_t>> \ __forceinline__ __device__ HALF operator OP(T x, HALF y) { \ return FLOAT2HALF(static_cast(x) OP HALF2FLOAT(y)); \ } #define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \ template < \ typename T, \ typename = cuda::std::enable_if_t>> \ __forceinline__ __device__ bool operator OP(HALF x, T y) { \ return HALF2FLOAT(x) OP static_cast(y); \ } \ template < \ typename T, \ typename = cuda::std::enable_if_t>> \ __forceinline__ __device__ bool operator OP(T x, HALF y) { \ return static_cast(y) OP HALF2FLOAT(x); \ } MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +) MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -) MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *) MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /) MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, 
__float2bfloat16, +) MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -) MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *) MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /) MLX_DEFINE_HALF_CMP(__half, __half2float, <) MLX_DEFINE_HALF_CMP(__half, __half2float, >) MLX_DEFINE_HALF_CMP(__half, __half2float, <=) MLX_DEFINE_HALF_CMP(__half, __half2float, >=) MLX_DEFINE_HALF_CMP(__half, __half2float, ==) MLX_DEFINE_HALF_CMP(__half, __half2float, !=) MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <) MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >) MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=) MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=) MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==) MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=) #undef MLX_DEFINE_HALF_OP #undef MLX_DEFINE_HALF_CMP } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device/gather.cuh ================================================ // Copyright © 2025 Apple Inc. 
#pragma once

#include "mlx/backend/cuda/device/indexing.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include

namespace mlx::core::cu {

namespace cg = cooperative_groups;

// Gathers `size` output elements, one thread per element. Output position
// out_idx is split into an offset inside one slice (src_elem) and a position
// in the index arrays (idx_elem); each of the NIDX index arrays then moves
// src_loc along its axis from `axes`. Negative indices are wrapped by
// absolute_index.
template
__global__ void gather(
    const T* src,
    T* out,
    LocT size,
    const __grid_constant__ Shape src_shape,
    const __grid_constant__ Strides src_strides,
    int32_t src_ndim,
    const __grid_constant__ Shape slice_sizes,
    uint32_t slice_size,
    const __grid_constant__ cuda::std::array axes,
    const __grid_constant__ cuda::std::array indices,
    const __grid_constant__ cuda::std::array indices_shape,
    const __grid_constant__ cuda::std::array indices_strides) {
  LocT out_idx = cg::this_grid().thread_rank();
  if (out_idx >= size) {
    return;
  }

  LocT src_elem = out_idx % slice_size;
  LocT idx_elem = out_idx / slice_size;

  // Offset of this element within the slice.
  LocT src_loc =
      elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim);

#pragma unroll
  for (int i = 0; i < NIDX; ++i) {
    // Shapes/strides of the index arrays are packed back to back, IDX_NDIM
    // entries per array.
    LocT idx_loc = elem_to_loc_nd(
        idx_elem,
        indices_shape.data() + i * IDX_NDIM,
        indices_strides.data() + i * IDX_NDIM);
    int32_t axis = axes[i];
    LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]);
    src_loc += idx_val * src_strides[axis];
  }

  out[out_idx] = src[src_loc];
}

} // namespace mlx::core::cu
================================================
FILE: mlx/backend/cuda/device/gather_axis.cuh
================================================
// Copyright © 2025 Apple Inc.
#pragma once

#include "mlx/backend/cuda/device/indexing.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include

namespace mlx::core::cu {

namespace cg = cooperative_groups;

// Gather along a single `axis`. The index array and the output are both laid
// out as (pre, idx_size_axis, post). The flat thread index is decomposed by
// index_to_dims as index = x * idx_size_axis * idx_size_pre +
// y * idx_size_pre + z, so z is the pre position, y the position along the
// axis and x the post position.
template <
    typename T,
    typename IdxT,
    int NDIM,
    bool SrcC,
    bool IdxC,
    typename LocT>
__global__ void gather_axis(
    const T* src,
    const IdxT* indices,
    T* out,
    LocT idx_size_pre,
    LocT idx_size_axis,
    LocT idx_size_post,
    const __grid_constant__ cuda::std::array shape,
    const __grid_constant__ cuda::std::array src_strides,
    const __grid_constant__ cuda::std::array idx_strides,
    int32_t axis,
    int32_t axis_size,
    int64_t src_stride_axis,
    int64_t idx_stride_axis) {
  LocT index = cg::this_grid().thread_rank();
  if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
    return;
  }

  auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);

  // Flat offset of the pre position, with the axis dimension removed.
  LocT elem_idx = z * idx_size_post;

  // Location of the index value: linear arithmetic when IdxC, otherwise
  // elem_to_loc_nd with the index strides.
  LocT idx_loc = y * idx_stride_axis;
  if constexpr (IdxC) {
    idx_loc += elem_idx * idx_size_axis + x;
  } else {
    idx_loc += elem_to_loc_nd(elem_idx + x, shape.data(), idx_strides.data());
  }

  auto idx_val = absolute_index(indices[idx_loc], axis_size);

  // Source location: the looked-up index along `axis`, the same (pre, post)
  // position elsewhere; linear arithmetic when SrcC.
  LocT src_loc = idx_val * src_stride_axis;
  if constexpr (SrcC) {
    src_loc += elem_idx * axis_size + x;
  } else {
    src_loc += elem_to_loc_nd(elem_idx + x, shape.data(), src_strides.data());
  }

  // The output is written contiguously in (pre, idx_size_axis, post) order.
  LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x;
  out[out_idx] = src[src_loc];
}

} // namespace mlx::core::cu
================================================
FILE: mlx/backend/cuda/device/hadamard.cuh
================================================
// Copyright © 2025 Apple Inc.
#pragma once #include "mlx/backend/cuda/device/utils.cuh" namespace mlx::core::cu { __device__ __forceinline__ void hadamard_radix_m(float* x); template struct Pow2Log2 { static_assert( (N > 0) && ((N & (N - 1)) == 0), "N must be a positive power of two."); static constexpr int value = 1 + Pow2Log2::value; }; template <> struct Pow2Log2<1> { static constexpr int value = 0; }; template __device__ __forceinline__ void hadamard_radix_pow2(float* x) { constexpr int kLogR = Pow2Log2::value; int h = 1; #pragma unroll for (int s = 0; s < kLogR; ++s) { #pragma unroll for (int i = 0; i < R / 2; ++i) { int k = i & (h - 1); int j = ((i - k) << 1) + k; float a = x[j]; float b = x[j + h]; x[j] = a + b; x[j + h] = a - b; } h <<= 1; } } template __global__ void hadamard_n(const T* in, T* out, float scale, long long num_transforms) { constexpr int kNumThreads = N / max_radix; constexpr int kLogN = Pow2Log2::value; constexpr int kLogR = Pow2Log2::value; constexpr int kNumSteps = kLogN / kLogR; constexpr int kLogFinal = kLogN % kLogR; constexpr int kFinalRadix = 1 << kLogFinal; if (threadIdx.x >= kNumThreads) { return; } __shared__ T buf[N]; int i = threadIdx.x; for (long long transform = blockIdx.x; transform < num_transforms; transform += gridDim.x) { long long base = (transform / stride) * static_cast(N) * stride + (transform % stride); if constexpr (stride == 1) { #pragma unroll for (int j = 0; j < max_radix / read_width; ++j) { int index = j * read_width * kNumThreads + i * read_width; #pragma unroll for (int r = 0; r < read_width; ++r) { buf[index + r] = in[base + index + r]; } } } else { #pragma unroll for (int j = 0; j < max_radix; ++j) { buf[j * kNumThreads + i] = in[base + (j * kNumThreads + i) * stride]; } } __syncthreads(); float x[max_radix]; int h = 1; #pragma unroll for (int s = 0; s < kNumSteps; ++s) { int k = i & (h - 1); int j = ((i - k) << kLogR) + k; #pragma unroll for (int r = 0; r < max_radix; ++r) { x[r] = static_cast(buf[j + h * r]); } hadamard_radix_pow2(x); 
#pragma unroll for (int r = 0; r < max_radix; ++r) { buf[j + h * r] = static_cast(x[r]); } h <<= kLogR; __syncthreads(); } if constexpr (kFinalRadix > 1) { #pragma unroll for (int t = 0; t < max_radix / kFinalRadix; ++t) { int index = i + t * kNumThreads; int k = index & (h - 1); int j = ((index - k) << kLogFinal) + k; #pragma unroll for (int r = 0; r < kFinalRadix; ++r) { x[r] = static_cast(buf[j + h * r]); } hadamard_radix_pow2(x); #pragma unroll for (int r = 0; r < kFinalRadix; ++r) { buf[j + h * r] = static_cast(x[r]); } } __syncthreads(); } if constexpr (stride == 1) { #pragma unroll for (int j = 0; j < max_radix / read_width; ++j) { int index = j * read_width * kNumThreads + i * read_width; #pragma unroll for (int r = 0; r < read_width; ++r) { float val = static_cast(buf[index + r]); out[base + index + r] = static_cast(val * scale); } } } else { #pragma unroll for (int j = 0; j < max_radix; ++j) { out[base + (j * kNumThreads + i) * stride] = buf[j * kNumThreads + i]; } } __syncthreads(); } } template __global__ void hadamard_m(const T* in, T* out, float scale, long long num_tasks) { constexpr int kTasksPerBatch = N / read_width; for (long long task = blockIdx.x * blockDim.x + threadIdx.x; task < num_tasks; task += blockDim.x * gridDim.x) { long long i = task % kTasksPerBatch; long long batch = task / kTasksPerBatch; long long base = batch * static_cast(M) * N; float x[read_width][M]; #pragma unroll for (int c = 0; c < M; ++c) { #pragma unroll for (int r = 0; r < read_width; ++r) { x[r][c] = static_cast(in[base + c * N + i * read_width + r]); } } #pragma unroll for (int r = 0; r < read_width; ++r) { hadamard_radix_m(x[r]); } #pragma unroll for (int c = 0; c < M; ++c) { #pragma unroll for (int r = 0; r < read_width; ++r) { out[base + c * N + i * read_width + r] = static_cast(x[r][c] * scale); } } } } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device/indexing.cuh 
================================================ // Copyright © 2025 Apple Inc. #include #include namespace mlx::core::cu { // Convert an absolute index to positions in a 3d grid, assuming the index is // calculated with: // index = x * dim1 * dim2 + y * dim2 + z template inline __host__ __device__ cuda::std::tuple index_to_dims(T index, T dim1, T dim2) { T x = index / (dim1 * dim2); T y = (index % (dim1 * dim2)) / dim2; T z = index % dim2; return cuda::std::make_tuple(x, y, z); } // Get absolute index from possible negative index. template inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) { if constexpr (cuda::std::is_unsigned_v) { return idx; } else { return static_cast(idx < 0 ? idx + size : idx); } } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device/scatter.cuh ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device/indexing.cuh" #include "mlx/backend/cuda/device/scatter_ops.cuh" #include "mlx/backend/cuda/device/utils.cuh" #include namespace mlx::core::cu { namespace cg = cooperative_groups; template < typename T, typename IdxT, typename Op, int NIDX, int IDX_NDIM, typename LocT> __global__ void scatter( const T* upd, T* out, LocT size, const __grid_constant__ Shape upd_shape, const __grid_constant__ Strides upd_strides, int32_t upd_ndim, LocT upd_post_idx_size, const __grid_constant__ Shape out_shape, const __grid_constant__ Strides out_strides, int32_t out_ndim, const __grid_constant__ cuda::std::array axes, const __grid_constant__ cuda::std::array indices, const __grid_constant__ cuda::std::array indices_shape, const __grid_constant__ cuda::std::array indices_strides) { LocT upd_idx = cg::this_grid().thread_rank(); if (upd_idx >= size) { return; } LocT out_elem = upd_idx % upd_post_idx_size; LocT idx_elem = upd_idx / upd_post_idx_size; LocT out_idx = elem_to_loc( out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), 
out_ndim); #pragma unroll for (int i = 0; i < NIDX; ++i) { LocT idx_loc = elem_to_loc_nd( idx_elem, indices_shape.data() + i * IDX_NDIM, indices_strides.data() + i * IDX_NDIM); int32_t axis = axes[i]; LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]); out_idx += idx_val * out_strides[axis]; } LocT upd_loc = elem_to_loc( out_elem + idx_elem * upd_post_idx_size, upd_shape.data(), upd_strides.data(), upd_ndim); Op{}(out + out_idx, upd[upd_loc]); } template __global__ void masked_scatter( const T* dst, const bool* mask, const int32_t* scatter_offsets, const T* src, T* out, IdxT size, IdxT src_batch_size, IdxT mask_batch_size, const __grid_constant__ Shape dst_shape, const __grid_constant__ Strides dst_strides, int32_t dst_ndim, const __grid_constant__ Shape src_shape, const __grid_constant__ Strides src_strides, int32_t src_ndim) { IdxT index = cg::this_grid().thread_rank(); if (index >= size) { return; } T dst_val; if constexpr (DstContiguous) { dst_val = dst[index]; } else { IdxT dst_loc = elem_to_loc(index, dst_shape.data(), dst_strides.data(), dst_ndim); dst_val = dst[dst_loc]; } if (mask[index]) { IdxT src_index = static_cast(scatter_offsets[index]); if (src_index < src_batch_size) { IdxT batch_idx = index / mask_batch_size; if constexpr (SrcContiguous) { out[index] = src[batch_idx * src_batch_size + src_index]; } else { IdxT src_elem = batch_idx * src_batch_size + src_index; IdxT src_loc = elem_to_loc( src_elem, src_shape.data(), src_strides.data(), src_ndim); out[index] = src[src_loc]; } return; } } out[index] = dst_val; } template __global__ void masked_scatter_vec_contiguous( const T* dst, const bool* mask, const int32_t* scatter_offsets, const T* src, T* out, IdxT size, IdxT src_batch_size, IdxT mask_batch_size) { IdxT vec_index = cg::this_grid().thread_rank(); IdxT base = vec_index * N_READS; if (base >= size) { return; } auto out_vec = load_vector(dst, vec_index, size, static_cast(0)); auto mask_vec = load_vector(mask, vec_index, size, 
false); auto offset_vec = load_vector(scatter_offsets, vec_index, size, 0); #pragma unroll for (int i = 0; i < N_READS; ++i) { IdxT index = base + i; if (index >= size) { break; } if (mask_vec[i]) { IdxT src_index = static_cast(offset_vec[i]); if (src_index < src_batch_size) { IdxT batch_idx = index / mask_batch_size; out_vec[i] = src[batch_idx * src_batch_size + src_index]; } } } store_vector(out, vec_index, out_vec, size); } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device/scatter_axis.cuh ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device/indexing.cuh" #include "mlx/backend/cuda/device/scatter_ops.cuh" #include "mlx/backend/cuda/device/utils.cuh" #include namespace mlx::core::cu { namespace cg = cooperative_groups; template < typename T, typename IdxT, typename Op, int NDIM, bool UpdC, bool IdxC, typename LocT> __global__ void scatter_axis( const T* upd, const IdxT* indices, T* out, LocT idx_size_pre, LocT idx_size_axis, LocT idx_size_post, const __grid_constant__ cuda::std::array shape, const __grid_constant__ cuda::std::array upd_strides, const __grid_constant__ cuda::std::array idx_strides, int32_t axis, int32_t axis_size, int64_t upd_stride_axis, int64_t idx_stride_axis) { LocT index = cg::this_grid().thread_rank(); if (index >= idx_size_pre * idx_size_axis * idx_size_post) { return; } auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre); LocT elem_idx = z * idx_size_post; LocT idx_loc = y * idx_stride_axis; if constexpr (IdxC) { idx_loc += elem_idx * idx_size_axis + x; } else { idx_loc += elem_to_loc_nd(elem_idx + x, shape.data(), idx_strides.data()); } auto idx_val = absolute_index(indices[idx_loc], axis_size); LocT upd_loc = y * upd_stride_axis; if constexpr (UpdC) { upd_loc += elem_idx * idx_size_axis + x; } else { upd_loc += elem_to_loc_nd(elem_idx + x, shape.data(), upd_strides.data()); } LocT out_idx = idx_val * 
idx_size_post + elem_idx * axis_size + x; Op{}(out + out_idx, upd[upd_loc]); } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device/scatter_ops.cuh ================================================ // Copyright © 2025 Apple Inc. #pragma once #include "mlx/backend/cuda/device/atomic_ops.cuh" namespace mlx::core::cu { struct ScatterAssign { template __device__ void operator()(T* out, T val) const { *out = val; } }; struct ScatterSum { template __device__ void operator()(T* out, T val) const { atomic_add(out, val); } }; struct ScatterProd { template __device__ void operator()(T* out, T val) const { atomic_prod(out, val); } }; struct ScatterMax { template __device__ void operator()(T* out, T val) const { atomic_max(out, val); } }; struct ScatterMin { template __device__ void operator()(T* out, T val) const { atomic_min(out, val); } }; } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device/slice_update.cuh ================================================ // Copyright © 2025 Apple Inc. 
#pragma once #include "mlx/backend/cuda/device/binary_ops.cuh" #include "mlx/backend/cuda/device/utils.cuh" #include namespace mlx::core::cu { namespace cg = cooperative_groups; template < typename T, typename IdxT, typename Op, bool OUT_ROW_CONTIG, bool UPD_ROW_CONTIG, bool UPD_SCALAR, int NWORK> __global__ void slice_update_op( const T* updates, T* out, int64_t update_size, const __grid_constant__ Shape update_shape, const __grid_constant__ Strides update_strides, int32_t update_ndim, const __grid_constant__ Strides output_strides, int64_t output_offset) { Op op; IdxT idx = cg::this_grid().thread_rank() * NWORK; IdxT out_idx; IdxT update_idx; if constexpr (OUT_ROW_CONTIG) { out_idx = idx; } else { out_idx = elem_to_loc( idx, update_shape.data(), output_strides.data(), update_ndim); } if constexpr (!UPD_SCALAR) { if constexpr (UPD_ROW_CONTIG) { update_idx = idx; } else { update_idx = elem_to_loc( idx, update_shape.data(), update_strides.data(), update_ndim); } } else { update_idx = 0; } out += output_offset; for (int j = 0; j < NWORK && idx < update_size; j++) { out[out_idx] = op(out[out_idx], updates[update_idx]); idx++; if constexpr (OUT_ROW_CONTIG) { out_idx = idx; } else { out_idx += output_strides[update_ndim - 1]; } if constexpr (UPD_ROW_CONTIG) { update_idx = idx; } else if constexpr (!UPD_SCALAR) { update_idx += update_strides[update_ndim - 1]; } } } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device/ternary_ops.cuh ================================================ // Copyright © 2025 Apple Inc. #pragma once namespace mlx::core::cu { struct Select { template __device__ T operator()(bool condition, T x, T y) { return condition ? x : y; } }; } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device/unary_ops.cuh ================================================ // Copyright © 2025 Apple Inc. 
#pragma once

#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include
#include
#include

namespace mlx::core::cu {

// Element-wise unary functors for device kernels. Most forward directly to
// the matching cuda::std function; the non-obvious cases are commented.

// Unsigned inputs are returned as-is; everything else uses cuda::std::abs.
struct Abs {
  template
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_unsigned_v) {
      return x;
    } else {
      return cuda::std::abs(x);
    }
  }
};

struct ArcCos {
  template
  __device__ T operator()(T x) {
    return cuda::std::acos(x);
  }
};

struct ArcCosh {
  template
  __device__ T operator()(T x) {
    return cuda::std::acosh(x);
  }
};

struct ArcSin {
  template
  __device__ T operator()(T x) {
    return cuda::std::asin(x);
  }
};

struct ArcSinh {
  template
  __device__ T operator()(T x) {
    return cuda::std::asinh(x);
  }
};

struct ArcTan {
  template
  __device__ T operator()(T x) {
    return cuda::std::atan(x);
  }
};

struct ArcTanh {
  template
  __device__ T operator()(T x) {
    return cuda::std::atanh(x);
  }
};

struct BitwiseInvert {
  template
  __device__ T operator()(T x) {
    return ~x;
  }
};

// Integral inputs are returned as-is; complex inputs are rounded up
// component-wise.
struct Ceil {
  template
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_integral_v) {
      return x;
    } else if constexpr (is_complex_v) {
      return T{cuda::std::ceil(x.real()), cuda::std::ceil(x.imag())};
    } else {
      return cuda::std::ceil(x);
    }
  }
};

struct Conjugate {
  template
  __device__ complex_t operator()(complex_t x) {
    return cuda::std::conj(x);
  }
};

struct Cos {
  template
  __device__ T operator()(T x) {
    return cuda::std::cos(x);
  }
};

struct Cosh {
  template
  __device__ T operator()(T x) {
    return cuda::std::cosh(x);
  }
};

// The branches calling __half2float / __bfloat162float evaluate erf in float;
// the float result is converted back to T by the return statement.
struct Erf {
  template
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v) {
      return erf(__half2float(x));
    } else if constexpr (cuda::std::is_same_v) {
      return erf(__bfloat162float(x));
    } else {
      return erf(x);
    }
  }
};

// Same float round-trip as Erf for the half types.
struct ErfInv {
  template
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v) {
      return erfinv(__half2float(x));
    } else if constexpr (cuda::std::is_same_v) {
      return erfinv(__bfloat162float(x));
    } else {
      return erfinv(x);
    }
  }
};

struct Exp {
  template
  __device__ T operator()(T x) {
    return cuda::std::exp(x);
  }
};

struct Expm1 {
  template
  __device__ T operator()(T x) {
    return cuda::std::expm1(x);
  }
};

// Integral inputs are returned as-is; complex inputs are floored
// component-wise.
struct Floor {
  template
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_integral_v) {
      return x;
    } else if constexpr (is_complex_v) {
      return T{cuda::std::floor(x.real()), cuda::std::floor(x.imag())};
    } else {
      return cuda::std::floor(x);
    }
  }
};

// Imaginary part of a complex value.
struct Imag {
  template
  __device__ auto operator()(complex_t x) {
    return x.imag();
  }
};

struct Log {
  template
  __device__ T operator()(T x) {
    return cuda::std::log(x);
  }
};

// Complex log2 divides both components of the natural log by ln(2)
// (CUDART_LN2_F).
struct Log2 {
  template
  __device__ T operator()(T x) {
    if constexpr (is_complex_v) {
      auto y = Log{}(x);
      return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F};
    } else {
      return cuda::std::log2(x);
    }
  }
};

struct Log10 {
  template
  __device__ T operator()(T x) {
    return cuda::std::log10(x);
  }
};

// For complex z = x + iy, the imaginary part is atan2(y, x + 1) and the real
// part is log|1 + z|. Since |1 + z|^2 = 1 + x * (2 + x) + y^2, for |z| < 0.5
// the real part is computed as 0.5 * log1p(x * (2 + x) + y * y); otherwise
// it is log(hypot(x + 1, y)). The complex path computes in float.
struct Log1p {
  template
  __device__ T operator()(T z) {
    if constexpr (is_complex_v) {
      float x = z.real();
      float y = z.imag();
      float zabs = Abs{}(z).real();
      float theta = atan2f(y, x + 1);
      if (zabs < 0.5f) {
        float r = x * (2 + x) + y * y;
        if (r == 0) { // handle underflow
          return {x, theta};
        }
        return {0.5f * log1pf(r), theta};
      } else {
        float z0 = hypotf(x + 1, y);
        return {logf(z0), theta};
      }
    } else {
      return cuda::std::log1p(z);
    }
  }
};

struct LogicalNot {
  __device__ bool operator()(bool x) {
    return !x;
  }
};

// Complex negation is written as 0 - x.
struct Negative {
  template
  __device__ T operator()(T x) {
    if constexpr (is_complex_v) {
      return T{0, 0} - x;
    } else {
      return -x;
    }
  }
};

// Real part of a complex value.
struct Real {
  template
  __device__ auto operator()(complex_t x) {
    return x.real();
  }
};

// Uses rint (round to nearest, ties to even, under the default rounding
// mode), component-wise for complex values.
struct Round {
  template
  __device__ T operator()(T x) {
    if constexpr (is_complex_v) {
      return {cuda::std::rint(x.real()), cuda::std::rint(x.imag())};
    } else {
      return cuda::std::rint(x);
    }
  }
};

// y = 1 / (1 + exp(|x|)) equals sigmoid(-|x|); for x >= 0 the result is
// recovered as 1 - y. If exp(|x|) overflows to infinity, y becomes 0 and the
// result saturates to 0 or 1 instead of producing NaN.
struct Sigmoid {
  template
  __device__ T operator()(T x) {
    T y = 1 / (1 + cuda::std::exp(cuda::std::abs(x)));
    return (x < 0) ? y : 1 - y;
  }
};

// Unsigned: x != 0. Complex: x / |x|, with zero returned unchanged.
// Otherwise (x > 0) - (x < 0); one floating-point branch wraps that
// difference in an explicit static_cast.
struct Sign {
  template
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_unsigned_v) {
      return x != 0;
    } else if constexpr (is_complex_v) {
      if (x.real() == 0 && x.imag() == 0) {
        return x;
      } else {
        return x / Abs()(x);
      }
    } else if constexpr (cuda::std::is_same_v) {
      return static_cast((x > T(0.f)) - (x < T(0.f)));
    } else {
      return (x > T(0)) - (x < T(0));
    }
  }
};

struct Sin {
  template
  __device__ T operator()(T x) {
    return cuda::std::sin(x);
  }
};

struct Sinh {
  template
  __device__ T operator()(T x) {
    return cuda::std::sinh(x);
  }
};

struct Square {
  template
  __device__ T operator()(T x) {
    return x * x;
  }
};

struct Sqrt {
  template
  __device__ T operator()(T x) {
    return cuda::std::sqrt(x);
  }
};

// Complex inputs use 1.0f / sqrt(x); the branches calling __half2float /
// __bfloat162float evaluate rsqrt in float.
struct Rsqrt {
  template
  __device__ T operator()(T x) {
    if constexpr (is_complex_v) {
      return 1.0f / Sqrt{}(x);
    } else if constexpr (cuda::std::is_same_v) {
      return rsqrt(__half2float(x));
    } else if constexpr (cuda::std::is_same_v) {
      return rsqrt(__bfloat162float(x));
    } else {
      return rsqrt(x);
    }
  }
};

struct Tan {
  template
  __device__ T operator()(T x) {
    return cuda::std::tan(x);
  }
};

struct Tanh {
  template
  __device__ T operator()(T x) {
    return cuda::std::tanh(x);
  }
};

// Converts to __nv_fp8_e4m3 and returns its raw storage byte (__x).
struct ToFP8 {
  template
  __device__ uint8_t operator()(T x) {
    return __nv_fp8_e4m3(x).__x;
  }
};

// Reinterprets a raw byte as __nv_fp8_e4m3 and converts it to float.
struct FromFP8 {
  __device__ float operator()(uint8_t x) {
    return float(*(__nv_fp8_e4m3*)(&x));
  }
};

} // namespace mlx::core::cu
================================================
FILE: mlx/backend/cuda/device/utils.cuh
================================================
// Copyright © 2025 Apple Inc.

// This file must not include any host-only code, utilities that work under both
// host and device can be put here.
//
// See more about the requirements at:
// https://docs.nvidia.com/cuda/nvrtc/#language

// NOTE(review): in this extract, `#include` targets and template
// parameter/argument lists (`<...>`) appear to have been stripped (e.g.
// `template struct ...`, `reinterpret_cast(x)`, `cuda::std::array;`). They
// are left untouched here; restore them from upstream before compiling.

#pragma once

#include "mlx/backend/cuda/device/complex.cuh"
#include "mlx/backend/cuda/device/config.h"

#include
#include
#include
#include
#include

namespace mlx::core::cu {

///////////////////////////////////////////////////////////////////////////////
// CUDA kernel utils
///////////////////////////////////////////////////////////////////////////////

// To pass shape/strides to kernels via constant memory, their size must be
// known at compile time.
using Shape = cuda::std::array;
using Strides = cuda::std::array;

// Vectorized load/store.
// AlignedVector holds N values of T and is aligned to its full size
// (sizeof(T) * N), so a suitably aligned T* can be reinterpreted as a pointer
// to AlignedVector and a whole vector read or written at once.
template
struct alignas(sizeof(T) * N) AlignedVector {
  T val[N];

  __device__ T& operator[](int i) {
    return val[i];
  }

  __device__ T operator[](int i) const {
    return val[i];
  }
};

// True when x is a multiple of N * sizeof(T) bytes, i.e. it may be viewed as
// an array of N-element AlignedVectors.
template
inline __host__ __device__ bool is_aligned(T* x) {
  return (reinterpret_cast(x) % (N * sizeof(T))) == 0;
}

// Loads the offset-th N-element vector of ptr. No alignment check is
// performed; ptr must already satisfy is_aligned.
template
inline __device__ AlignedVector unsafe_load_vector(
    const T* ptr,
    uint32_t offset) {
  auto* from = reinterpret_cast*>(ptr);
  return from[offset];
}

// Loads elements [offset * N, offset * N + N) of ptr: one vector load when ptr
// is aligned, otherwise N scalar loads.
template
inline __device__ AlignedVector load_vector(
    const T* ptr,
    uint32_t offset) {
  if (is_aligned(ptr)) {
    auto* from = reinterpret_cast*>(ptr);
    return from[offset];
  } else {
    AlignedVector v;
#pragma unroll
    for (int i = 0; i < N; ++i) {
      v[i] = ptr[offset * N + i];
    }
    return v;
  }
}

// Bounds-checked load: lanes whose element index is >= size receive fallback.
// The vector path is used only when the whole vector lies below size.
// NOTE(review): offset is uint32_t, so `(offset + 1) * N` / `offset * N + i`
// may be evaluated in 32-bit unsigned arithmetic and wrap for buffers of 2^32+
// elements — confirm callers stay below that.
template
inline __device__ AlignedVector
load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
  if (is_aligned(ptr) && (offset + 1) * N <= size) {
    auto* from = reinterpret_cast*>(ptr);
    return from[offset];
  } else {
    AlignedVector v;
#pragma unroll
    for (int i = 0; i < N; ++i) {
      v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
    }
    return v;
  }
}

// Strided, bounds-checked load: logical element j is read from ptr[stride * j].
// The vector path additionally requires stride == 1.
template
inline __device__ AlignedVector load_vector(
    const T* ptr,
    uint32_t offset,
    SizeT size,
    int64_t stride,
    T fallback) {
  if (is_aligned(ptr) && stride == 1 && (offset + 1) * N <= size) {
    auto* from = reinterpret_cast*>(ptr);
    return from[offset];
  } else {
    AlignedVector v;
#pragma unroll
    for (int i = 0; i < N; ++i) {
      v[i] =
          (N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
    }
    return v;
  }
}

// Stores vec as the offset-th N-element vector of ptr. No alignment check is
// performed.
template
inline __device__ void
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector& vec) {
  auto* to = reinterpret_cast*>(ptr);
  to[offset] = vec;
}

// Stores vec at element offset * N: one vector store when ptr is aligned,
// otherwise N scalar stores.
template
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector& vec) {
  if (is_aligned(ptr)) {
    auto* to = reinterpret_cast*>(ptr);
    to[offset] = vec;
  } else {
#pragma unroll
    for (int i = 0; i < N; ++i) {
      ptr[offset * N + i] = vec[i];
    }
  }
}

// Bounds-checked store: only lanes whose element index is < size are written.
template
inline __device__ void store_vector(
    T* ptr,
    uint32_t offset,
    const AlignedVector& vec,
    SizeT size) {
  if (is_aligned(ptr) && (offset + 1) * N <= size) {
    auto* to = reinterpret_cast*>(ptr);
    to[offset] = vec;
  } else {
    for (int i = 0; (offset * N + i) < size && i < N; ++i) {
      ptr[offset * N + i] = vec[i];
    }
  }
}

// Strided, bounds-checked store: logical element j goes to ptr[stride * j].
// The vector path additionally requires stride == 1.
template
inline __device__ void store_vector(
    T* ptr,
    uint32_t offset,
    const AlignedVector& vec,
    SizeT size,
    int64_t stride) {
  if (is_aligned(ptr) && (offset + 1) * N <= size && stride == 1) {
    auto* to = reinterpret_cast*>(ptr);
    to[offset] = vec;
  } else {
    for (int i = 0; (offset * N + i) < size && i < N; ++i) {
      ptr[stride * (offset * N + i)] = vec[i];
    }
  }
}

///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

// Generic limits: all four accessors forward to cuda::std::numeric_limits
// max()/min(). The specializations below make max()/min() infinite for
// floating types while finite_max()/finite_min() stay finite.
template
struct Limits {
  static constexpr __host__ __device__ T max() {
    return cuda::std::numeric_limits::max();
  }
  static constexpr __host__ __device__ T min() {
    return cuda::std::numeric_limits::min();
  }
  static constexpr __host__ __device__ T finite_max() {
    return cuda::std::numeric_limits::max();
  }
  static constexpr __host__ __device__ T finite_min() {
    return cuda::std::numeric_limits::min();
  }
};

// Floating-point specialization (its two is_same_v type arguments are stripped
// in this extract — presumably float and double). max()/min() are +/-infinity;
// finite_min() uses lowest(), the most negative finite value.
template
struct Limits<
    T,
    cuda::std::enable_if_t<
        cuda::std::is_same_v || cuda::std::is_same_v>> {
  static constexpr __host__ __device__ T max() {
    return cuda::std::numeric_limits::infinity();
  }
  static constexpr __host__ __device__ T min() {
    return -cuda::std::numeric_limits::infinity();
  }
  static constexpr __host__ __device__ T finite_max() {
    return cuda::std::numeric_limits::max();
  }
  static constexpr __host__ __device__ T finite_min() {
    return cuda::std::numeric_limits::lowest();
  }
};

// CUDA 11 does not have host side arithmetic operators for half types.
// (Presumably this specialization covers the 16-bit float types; its type
// arguments are stripped here, which is also why the two #if branches below
// read identically in this extract.)
template
struct Limits<
    T,
    cuda::std::enable_if_t<
        cuda::std::is_same_v || cuda::std::is_same_v>> {
  static constexpr __host__ __device__ T max() {
    return cuda::std::numeric_limits::infinity();
  }
  static constexpr __host__ __device__ T min() {
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
    return -cuda::std::numeric_limits::infinity();
#else
    return -cuda::std::numeric_limits::infinity();
#endif
  }
  static constexpr __host__ __device__ T finite_max() {
    return cuda::std::numeric_limits::max();
  }
  static constexpr __host__ __device__ T finite_min() {
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
    return cuda::std::numeric_limits::lowest();
#else
    return cuda::std::numeric_limits::lowest();
#endif
  }
};

// Full specialization whose accessors return bool (type argument stripped in
// this extract): max() is true, min() is false.
template <>
struct Limits {
  static constexpr __host__ __device__ bool max() {
    return true;
  }
  static constexpr __host__ __device__ bool min() {
    return false;
  }
};

// Complex specialization: both components take the component type's limits.
template
struct Limits> {
  static constexpr __host__ __device__ complex_t max() {
    return {Limits::max(), Limits::max()};
  }
  static constexpr __host__ __device__ complex_t min() {
    return {Limits::min(), Limits::min()};
  }
};

///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////

// Maps a flat row-major element index to a memory offset given shape/strides
// of ndim dimensions, peeling the last (fastest-varying) dimension first. The
// loop stops early once the remaining index is 0.
template
inline __host__ __device__ IdxT
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
  IdxT loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

// Optimize when the ndim is known at compile time.
template
inline __host__ __device__ IdxT
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
  IdxT loc = 0;
#pragma unroll
  for (int i = NDIM - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

// Compile-time-ndim variant that computes offsets for two arrays sharing one
// shape; returns (a_loc, b_loc).
template
inline __host__ __device__ cuda::std::tuple elem_to_loc_nd(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides) {
  IdxT a_loc = 0;
  IdxT b_loc = 0;
#pragma unroll
  for (int i = NDIM - 1; i >= 0; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * IdxT(a_strides[i]);
    b_loc += dim_idx * IdxT(b_strides[i]);
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc);
}

// Compile-time-ndim variant for three arrays sharing one shape; returns
// (a_loc, b_loc, c_loc).
template
inline __host__ __device__ cuda::std::tuple elem_to_loc_nd(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    const int64_t* c_strides) {
  IdxT a_loc = 0;
  IdxT b_loc = 0;
  IdxT c_loc = 0;
#pragma unroll
  for (int i = NDIM - 1; i >= 0; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * IdxT(a_strides[i]);
    b_loc += dim_idx * IdxT(b_strides[i]);
    c_loc += dim_idx * IdxT(c_strides[i]);
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc, c_loc);
}

// Runtime-ndim variant for two arrays sharing one shape.
template
inline __host__ __device__ cuda::std::tuple elem_to_loc(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    int ndim) {
  IdxT a_loc = 0;
  IdxT b_loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * IdxT(a_strides[i]);
    b_loc += dim_idx * IdxT(b_strides[i]);
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc);
}

// Runtime-ndim variant for three arrays sharing one shape.
template
inline __host__ __device__ cuda::std::tuple elem_to_loc(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    const int64_t* c_strides,
    int ndim) {
  IdxT a_loc = 0;
  IdxT b_loc = 0;
  IdxT c_loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * IdxT(a_strides[i]);
    b_loc += dim_idx * IdxT(b_strides[i]);
    c_loc += dim_idx * IdxT(c_strides[i]);
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc, c_loc);
}

///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////

// Tracks the memory offset of a running flat index incrementally, so a loop
// can step through elements without calling elem_to_loc on every step. This
// level owns dimension dim - 1 (strides[dim - 1]); inner_looper, constructed
// with dim - 1, tracks the slower-varying dimensions before it and receives
// the carries. (Template parameters are stripped here; the <1, true, OffsetT>
// and <1, false, OffsetT> specializations below indicate a dimension count, a
// bool flag and the offset type.)
template
struct LoopedElemToLoc {
  int dim;
  LoopedElemToLoc inner_looper;
  OffsetT offset{0};
  int index{0};

  __device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

  // Advances one element. When index reaches shape[dim - 1] it wraps to 0,
  // the carry goes to inner_looper, and the offset restarts from its offset.
  __device__ void next(const int* shape, const int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index++;
    offset += OffsetT(strides[dim - 1]);
    if (index >= shape[dim - 1]) {
      index = 0;
      inner_looper.next(shape, strides);
      offset = inner_looper.offset;
    }
  }

  // Advances n elements. On overflow the carry is pushed into inner_looper
  // (as a single multi-step call when it spans several wraps), the offset is
  // reset to inner_looper's, and any leftover `extra` steps are re-applied
  // recursively at this level.
  __device__ void next(int n, const int* shape, const int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index += n;
    offset += n * OffsetT(strides[dim - 1]);

    if (index >= shape[dim - 1]) {
      int extra = index - shape[dim - 1];
      if (extra >= shape[dim - 1]) {
        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
        extra = extra % shape[dim - 1];
      } else {
        inner_looper.next(shape, strides);
      }
      index = 0;
      offset = inner_looper.offset;
      if (extra > 0) {
        next(extra, shape, strides);
      }
    }
  }

  __device__ OffsetT location() {
    return offset;
  }
};

// Terminal level with the flag set: when dim > 1 the offset is recomputed
// from the flat index with elem_to_loc on every step; otherwise it advances
// by strides[0].
template
struct LoopedElemToLoc<1, true, OffsetT> {
  int dim;
  OffsetT offset{0};
  int index{0};

  __device__ LoopedElemToLoc(int dim) : dim(dim) {}

  __device__ void next(const int* shape, const int64_t* strides) {
    index++;
    if (dim > 1) {
      offset = elem_to_loc(index, shape, strides, dim);
    } else {
      offset += OffsetT(strides[0]);
    }
  }

  __device__ void next(int n, const int* shape, const int64_t* strides) {
    index += n;
    if (dim > 1) {
      offset = elem_to_loc(index, shape, strides, dim);
    } else {
      offset = index * OffsetT(strides[0]);
    }
  }

  __device__ OffsetT location() {
    return offset;
  }
};

// Terminal level with the flag cleared: the offset simply advances by
// strides[0] per step; shape and dim are ignored.
template
struct LoopedElemToLoc<1, false, OffsetT> {
  OffsetT offset{0};

  __device__ LoopedElemToLoc(int) {}

  __device__ void next(const int*, const int64_t* strides) {
    offset += OffsetT(strides[0]);
  }

  __device__ void next(int n, const int*, const int64_t* strides) {
    offset += n * OffsetT(strides[0]);
  }

  __device__ OffsetT location() {
    return offset;
  }
};

} // namespace mlx::core::cu

================================================
FILE: mlx/backend/cuda/device.cpp
================================================
// Copyright © 2025 Apple Inc.

// NOTE(review): `#include` targets and template argument lists (`<...>`)
// appear stripped in this extract; they are left as found.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/gpu/device_info.h"
#include "mlx/utils.h"

#include
#include
#include
#include

namespace mlx::core::cu {

namespace {

// Whether work is recorded into CUDA graphs. Read once from the
// MLX_USE_CUDA_GRAPHS environment variable (default true) and cached.
bool use_cuda_graphs() {
  static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
  return use_graphs;
}

// Value of MLX_SAVE_CUDA_GRAPHS_DOT_FILE, evaluated once; nullptr when the
// variable is unset or empty. commit() uses it as a prefix for .dot dumps.
const char* save_cuda_graphs_dot_file() {
  static const char* filename = []() -> const char* {
    const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
    if (env && std::strlen(env) == 0) {
      return nullptr;
    }
    return env;
  }();
  return filename;
}

// True for a dim3 that is all zeros or all ones, i.e. no thread-block cluster
// was requested.
inline bool is_empty_dim(dim3 dim) {
  return (dim.x == 0 && dim.y == 0 && dim.z == 0) ||
      (dim.x == 1 && dim.y == 1 && dim.z == 1);
}

} // namespace

// Queries and caches the device attributes exposed by Device's accessors.
Device::Device(int device) : device_(device) {
  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
      &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
      &compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_));
  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
      &concurrent_managed_access_,
      cudaDevAttrConcurrentManagedAccess,
      device_));
  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
      &host_native_atomic_, cudaDevAttrHostNativeAtomicSupported, device_));
  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
      &managed_memory_, cudaDevAttrManagedMemory, device_));
  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
      &memory_pools_, cudaDevAttrMemoryPoolsSupported, device_));
}

// Destroys the lazily created cuDNN / cuBLASLt handles, if any.
// NOTE(review): Device's move constructor is defaulted in device.h, so a
// moved-from Device still holds the same raw handles and would destroy them a
// second time; this is only safe while the handles are still null.
Device::~Device() {
  if (cudnn_handle_) {
    CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_handle_));
  }
  if (cublaslt_handle_) {
    CHECK_CUBLAS_ERROR(cublasLtDestroy(cublaslt_handle_));
  }
}

void Device::make_current() {
  // We need to set/get current CUDA device very frequently, cache it to reduce
  // actual calls of CUDA APIs. Use -1 as sentinel so the first call on each
  // new thread always calls cudaSetDevice (which establishes the CUDA primary
  // context). Without this, device 0 would never get set on a new thread.
  static thread_local int current = -1;
  if (current != device_) {
    CHECK_CUDA_ERROR(cudaSetDevice(device_));
    current = device_;
  }
}

// Returns the encoder for stream s, creating it (bound to this Device) on
// first use.
CommandEncoder& Device::get_command_encoder(Stream s) {
  auto it = encoders_.find(s.index);
  if (it == encoders_.end()) {
    it = encoders_.try_emplace(s.index, *this).first;
  }
  return it->second;
}

// Returns the cuBLASLt handle, creating it on first use with this device
// made current.
cublasLtHandle_t Device::get_cublaslt_handle() {
  if (!cublaslt_handle_) {
    make_current();
    CHECK_CUBLAS_ERROR(cublasLtCreate(&cublaslt_handle_));
  }
  return cublaslt_handle_;
}

// Returns the cuDNN handle, creating it on first use with this device made
// current.
cudnnHandle_t Device::get_cudnn_handle() {
  if (!cudnn_handle_) {
    make_current();
    CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_handle_));
  }
  return cudnn_handle_;
}

// Makes the device current and, when CUDA graphs are enabled, begins
// thread-local stream capture on the encoder's stream.
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
  enc.device().make_current();
  if (!use_cuda_graphs()) {
    return;
  }
  CHECK_CUDA_ERROR(
      cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeThreadLocal));
}

// Ends capture and adds the captured graph as a child node unless `discard`
// was set. Without graphs only node_count_ is bumped (no capture was started).
CommandEncoder::CaptureContext::~CaptureContext() {
  if (!use_cuda_graphs()) {
    enc.node_count_++;
    return;
  }

  graph.end_capture(enc.stream());
  if (discard) {
    return;
  }
  enc.add_graph_node(graph);
}

// Enters a concurrent region: nodes added until destruction are collected in
// concurrent_nodes_ instead of being wired in one by one.
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
    : enc(enc) {
  enc.in_concurrent_ = true;
}

// Leaves the concurrent region: all collected nodes get the accumulated input
// dependencies and are joined by a single empty node, which becomes the
// producer of the region's outputs.
// NOTE(review): the empty node's type "E" is not appended to
// graph_nodes_key_; only its id reaches graph_deps_key_.
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
  enc.in_concurrent_ = false;
  if (!use_cuda_graphs()) {
    return;
  }

  // Use an empty graph node for synchronization
  CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
  CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));

  // Insert the concurrent -> empty node dependencies
  for (auto& from : enc.concurrent_nodes_) {
    enc.from_nodes_.push_back(from.node);
    enc.to_nodes_.push_back(empty.node);
    enc.graph_deps_key_ += from.id;
    enc.graph_deps_key_ += "-";
    enc.graph_deps_key_ += empty.id;
    enc.graph_deps_key_ += "-";
  }

  // Insert the input -> concurrent node dependencies without updating output
  // nodes
  auto outputs = std::move(enc.active_outputs_);
  enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_));

  // Update output node to be the empty node
  for (auto o : outputs) {
    enc.node_map_.emplace(o, empty).first->second = empty;
  }
}

// Gives the node the next sequential id. Inside a concurrent region it is
// deferred to concurrent_nodes_; otherwise it is wired in immediately.
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
  node.id = std::to_string(node_count_++);
  if (in_concurrent_) {
    concurrent_nodes_.push_back(std::move(node));
  } else {
    std::vector nodes;
    nodes.push_back(std::move(node));
    insert_graph_dependencies(std::move(nodes));
  }
}

// Wires `nodes` into the pending graph:
//  - appends each node's type to graph_nodes_key_;
//  - for every active dependency buffer with a recorded producer (node_map_),
//    records one edge per distinct producer to each node;
//  - records the nodes as producers of the active output buffers;
//  - appends every edge to graph_deps_key_.
// Edges are only queued in from_nodes_/to_nodes_ here; commit() adds them to
// graph_, and the two keys form the lookup key for graph_cache_.
void CommandEncoder::insert_graph_dependencies(std::vector nodes) {
  for (auto& node : nodes) {
    graph_nodes_key_ += node.node_type;
    graph_nodes_key_ += "-";
  }

  std::vector deps;
  {
    // Dependencies must be added in the same order to produce a consistent
    // topology
    std::unordered_set set_deps;
    for (auto d : active_deps_) {
      if (auto it = node_map_.find(d); it != node_map_.end()) {
        auto [_, inserted] = set_deps.insert(it->second.node);
        if (inserted) {
          deps.push_back(it->second);
        }
      }
    }
  }
  active_deps_.clear();

  for (auto o : active_outputs_) {
    for (auto& node : nodes) {
      node_map_.emplace(o, node).first->second = node;
    }
  }
  active_outputs_.clear();

  for (auto& from : deps) {
    for (auto& to : nodes) {
      from_nodes_.push_back(from.node);
      to_nodes_.push_back(to.node);
      graph_deps_key_ += from.id;
      graph_deps_key_ += "-";
      graph_deps_key_ += to.id;
      graph_deps_key_ += "-";
    }
  }
}

// Can be tuned with MLX_MAX_OPS_PER_BUFFER,
MLX_MAX_MB_PER_BUFFER std::pair get_graph_limits(Device& d) { auto cc = d.compute_capability_major() * 100 + d.compute_capability_minor() * 10; int ops = 20; int mb = 100; switch (cc) { case 800: // A100 ops = 20; mb = 400; break; case 900: // H100 case 1000: // B200 case 1200: // Consumer Blackwell ops = 100; mb = 1000; break; case 1210: // DGX Spark ops = 20; mb = 25; break; } return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)}; } CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d), graph_(d), worker_(d), graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) { std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d); } void CommandEncoder::add_completed_handler(std::function task) { worker_.add_task(std::move(task)); } void CommandEncoder::set_input_array(const array& arr) { if (!use_cuda_graphs()) { return; } bytes_in_graph_ += arr.data_size(); auto id = reinterpret_cast(arr.buffer().ptr()); active_deps_.push_back(id); } void CommandEncoder::set_output_array(const array& arr) { if (!use_cuda_graphs()) { return; } auto id = reinterpret_cast(arr.buffer().ptr()); active_deps_.push_back(id); active_outputs_.push_back(id); } void CommandEncoder::add_kernel_node_raw( void* func, dim3 grid_dim, dim3 block_dim, dim3 cluster_dim, uint32_t smem_bytes, void** params) { bool use_cluster = !is_empty_dim(cluster_dim); assert(!use_cluster || device_.compute_capability_major() >= 9); if (!use_cuda_graphs()) { node_count_++; cudaLaunchConfig_t config = {}; config.gridDim = grid_dim; config.blockDim = block_dim; config.dynamicSmemBytes = smem_bytes; config.stream = stream(); cudaLaunchAttribute attr = {}; if (use_cluster) { attr.id = cudaLaunchAttributeClusterDimension; attr.val.clusterDim.x = cluster_dim.x; attr.val.clusterDim.y = cluster_dim.y; attr.val.clusterDim.z = cluster_dim.z; config.attrs = &attr; config.numAttrs = 1; } CHECK_CUDA_ERROR(cudaLaunchKernelExC(&config, func, params)); return; } cudaKernelNodeParams 
kernel_params = {0}; kernel_params.func = func; kernel_params.gridDim = grid_dim; kernel_params.blockDim = block_dim; kernel_params.kernelParams = params; kernel_params.sharedMemBytes = smem_bytes; cudaGraphNode_t node = add_kernel_node_raw(kernel_params); if (use_cluster) { cudaKernelNodeAttrValue attr = {}; attr.clusterDim.x = cluster_dim.x; attr.clusterDim.y = cluster_dim.y; attr.clusterDim.z = cluster_dim.z; CHECK_CUDA_ERROR(cudaGraphKernelNodeSetAttribute( node, cudaLaunchAttributeClusterDimension, &attr)); } } void CommandEncoder::add_kernel_node_raw( CUfunction func, dim3 grid_dim, dim3 block_dim, dim3 cluster_dim, uint32_t smem_bytes, void** params) { bool use_cluster = !is_empty_dim(cluster_dim); assert(!use_cluster || device_.compute_capability_major() >= 9); if (!use_cuda_graphs()) { node_count_++; CUlaunchConfig config = {}; config.gridDimX = grid_dim.x; config.gridDimY = grid_dim.y; config.gridDimZ = grid_dim.z; config.blockDimX = block_dim.x; config.blockDimY = block_dim.y; config.blockDimZ = block_dim.z; config.sharedMemBytes = smem_bytes; config.hStream = stream(); CUlaunchAttribute attr = {}; if (use_cluster) { attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; attr.value.clusterDim.x = cluster_dim.x; attr.value.clusterDim.y = cluster_dim.y; attr.value.clusterDim.z = cluster_dim.z; config.attrs = &attr; config.numAttrs = 1; } CHECK_CUDA_ERROR(cuLaunchKernelEx(&config, func, params, nullptr)); return; } CUDA_KERNEL_NODE_PARAMS kernel_params = {}; kernel_params.func = func; kernel_params.gridDimX = grid_dim.x; kernel_params.gridDimY = grid_dim.y; kernel_params.gridDimZ = grid_dim.z; kernel_params.blockDimX = block_dim.x; kernel_params.blockDimY = block_dim.y; kernel_params.blockDimZ = block_dim.z; kernel_params.kernelParams = params; kernel_params.sharedMemBytes = smem_bytes; CUgraphNode node = add_kernel_node_raw(kernel_params); if (use_cluster) { CUlaunchAttributeValue attr = {}; attr.clusterDim.x = cluster_dim.x; attr.clusterDim.y = cluster_dim.y; 
attr.clusterDim.z = cluster_dim.z; CHECK_CUDA_ERROR(cuGraphKernelNodeSetAttribute( node, CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION, &attr)); } } cudaGraphNode_t CommandEncoder::add_kernel_node_raw( const cudaKernelNodeParams& params) { cudaGraphNode_t node; CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms)); insert_graph_dependencies(GraphNode{node, "K"}); return node; } CUgraphNode CommandEncoder::add_kernel_node_raw( const CUDA_KERNEL_NODE_PARAMS& params) { CUgraphNode node; CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms)); insert_graph_dependencies(GraphNode{node, "K"}); return node; } std::pair subgraph_to_key(cudaGraph_t graph) { // Constructs a key representing the nodes of a sub-graph. // Also checks if the sub-graph is updatable as CUDA graphs do not get // updated correctly if a kernel node getting updated has a different cluster // shape than the node it's being updated with. std::string key = "("; size_t num_nodes = 0; CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes)); if (num_nodes == 0) { return {key + ")", true}; } bool is_updatable = true; std::vector nodes(num_nodes); CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes)); for (const auto& node : nodes) { if (!is_updatable) { break; } cudaGraphNodeType type; CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type)); switch (type) { case cudaGraphNodeTypeGraph: { // Try to be updatable for a structure like graph -> graph -> kernel cudaGraph_t child; CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child)); auto [subkey, sub_is_updatable] = subgraph_to_key(child); is_updatable &= sub_is_updatable; key += subkey; break; } case cudaGraphNodeTypeHost: key += "H"; break; case cudaGraphNodeTypeMemset: key += "M"; break; case cudaGraphNodeTypeKernel: { cudaLaunchAttributeValue cluster_dim; CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute( node, cudaLaunchAttributeClusterDimension, &cluster_dim)); // Only allow dim.x to be greater than 1 if 
(cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) { is_updatable = false; } else { key += "K"; key += std::to_string(cluster_dim.clusterDim.x); } break; } case cudaGraphNodeTypeWaitEvent: key += "W"; break; case cudaGraphNodeTypeEventRecord: key += "R"; break; default: is_updatable = false; } } key += ")"; return {key, is_updatable}; } void CommandEncoder::add_graph_node(cudaGraph_t child) { if (!use_cuda_graphs()) { node_count_++; CudaGraphExec graph_exec; graph_exec.instantiate(child); device_.make_current(); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream())); return; } cudaGraphNode_t node; auto [sub_graph_key, is_updatable] = subgraph_to_key(child); is_graph_updatable_ &= is_updatable; CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child)); insert_graph_dependencies(GraphNode{node, sub_graph_key}); } void CommandEncoder::add_graph_node( cudaGraph_t child, const std::string& subgraph_key, bool is_updatable) { if (!use_cuda_graphs()) { node_count_++; CudaGraphExec graph_exec; graph_exec.instantiate(child); device_.make_current(); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream())); return; } is_graph_updatable_ &= is_updatable; cudaGraphNode_t node; CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child)); insert_graph_dependencies(GraphNode{node, subgraph_key}); } bool CommandEncoder::needs_commit() { return (node_count_ > max_ops_per_graph_) || ((bytes_in_graph_ >> 20) > max_mb_per_graph_); } void CommandEncoder::commit() { nvtx3::scoped_range r("CommandEncoder::commit"); if (!temporaries_.empty()) { add_completed_handler([temporaries = std::move(temporaries_)]() {}); } if (use_cuda_graphs() && node_count_ > 0) { if (!from_nodes_.empty()) { #if CUDART_VERSION >= 13000 CHECK_CUDA_ERROR(cudaGraphAddDependencies( graph_, from_nodes_.data(), to_nodes_.data(), nullptr, // edgeData from_nodes_.size())); #else CHECK_CUDA_ERROR(cudaGraphAddDependencies( graph_, from_nodes_.data(), to_nodes_.data(), 
from_nodes_.size())); #endif } device_.make_current(); if (!is_graph_updatable_) { CudaGraphExec graph_exec; graph_exec.instantiate(graph_); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); } else { auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_; auto& graph_exec = graph_cache_[graph_key]; if (graph_exec != nullptr) { cudaGraphExecUpdateResult update_result; #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo info; cudaGraphExecUpdate(graph_exec, graph_, &info); update_result = info.result; #else cudaGraphNode_t error_node; cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result); #endif // CUDART_VERSION >= 12000 if (update_result != cudaGraphExecUpdateSuccess) { cudaGetLastError(); // reset error graph_exec.reset(); } } if (graph_exec == nullptr) { graph_exec.instantiate(graph_); } CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); } // Save cuda graph to dot file if (const char* filename = save_cuda_graphs_dot_file(); filename) { static int count = 0; auto path = fmt::format("{}_{}.dot", filename, ++count); CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0)); } // Reset state from_nodes_.clear(); to_nodes_.clear(); graph_deps_key_.clear(); graph_nodes_key_.clear(); node_map_.clear(); graph_ = CudaGraph(device_); is_graph_updatable_ = true; } // Put completion handlers in a batch. worker_.commit(stream_); node_count_ = 0; bytes_in_graph_ = 0; } void CommandEncoder::synchronize() { CHECK_CUDA_ERROR(cudaStreamSynchronize(stream_)); auto p = std::make_shared>(); std::future f = p->get_future(); add_completed_handler([p = std::move(p)]() { p->set_value(); }); commit(); f.wait(); } Device& device(int cuda_device) { static auto devices = []() { std::vector devices; int device_count = gpu::device_count(); for (int i = 0; i < device_count; ++i) { devices.emplace_back(i); } // Initialize the jit module cache here ensures it is not unloaded before // any evaluation is done. 
get_jit_module_cache(); return devices; }(); return devices.at(cuda_device); } Device& device(mlx::core::Device d) { return device(d.index); } CommandEncoder& get_command_encoder(Stream s) { return device(s.device).get_command_encoder(s); } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include "mlx/array.h" #include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/cuda/worker.h" #include "mlx/stream.h" #include #include #include #include namespace mlx::core::cu { // Compute a key and updatability flag for a CUDA graph by walking its nodes. std::pair subgraph_to_key(cudaGraph_t graph); class CommandEncoder { public: struct CaptureContext { CaptureContext(CommandEncoder& enc); ~CaptureContext(); CudaGraph graph; CommandEncoder& enc; bool discard{false}; }; struct ConcurrentContext { ConcurrentContext(CommandEncoder& enc); ~ConcurrentContext(); CommandEncoder& enc; }; explicit CommandEncoder(Device& d); CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; CaptureContext capture_context() { return CaptureContext{*this}; } ConcurrentContext concurrent_context() { return ConcurrentContext{*this}; } void set_input_array(const array& arr); void set_output_array(const array& arr); template void add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) { add_kernel_node_ex(func, grid_dim, block_dim, {}, 0, params...); } template void add_kernel_node_ex( F* func, dim3 grid_dim, dim3 block_dim, dim3 cluster_dim, uint32_t smem_bytes, Params&&... 
params) { constexpr size_t num = sizeof...(Params); void* ptrs[num]; size_t i = 0; ([&](auto&& p) { ptrs[i++] = static_cast(&p); }( std::forward(params)), ...); add_kernel_node_raw( reinterpret_cast(func), grid_dim, block_dim, cluster_dim, smem_bytes, ptrs); } void add_kernel_node_raw( void* func, dim3 grid_dim, dim3 block_dim, dim3 cluster_dim, uint32_t smem_bytes, void** params); void add_kernel_node_raw( CUfunction func, dim3 grid_dim, dim3 block_dim, dim3 cluster_dim, uint32_t smem_bytes, void** params); void add_graph_node(cudaGraph_t child); void add_graph_node( cudaGraph_t child, const std::string& subgraph_key, bool is_updatable); void add_temporary(const array& arr) { temporaries_.push_back(arr.data_shared_ptr()); } void add_completed_handler(std::function task); bool needs_commit(); void commit(); Device& device() { return device_; } CudaStream& stream() { return stream_; } // Wait until kernels and completion handlers are finished void synchronize(); private: cudaGraphNode_t add_kernel_node_raw(const cudaKernelNodeParams& params); CUgraphNode add_kernel_node_raw(const CUDA_KERNEL_NODE_PARAMS& params); struct GraphNode { cudaGraphNode_t node; // K = kernel // E = empty // () = subgraph (with metadata) // Symbols ':', '-' are reserved as separators std::string node_type; std::string id; }; void insert_graph_dependencies(GraphNode node); void insert_graph_dependencies(std::vector nodes); Device& device_; CudaStream stream_; CudaGraph graph_; Worker worker_; int node_count_{0}; bool in_concurrent_{false}; std::vector from_nodes_; std::vector to_nodes_; std::string graph_nodes_key_; std::string graph_deps_key_; std::vector concurrent_nodes_; std::vector> temporaries_; LRUCache graph_cache_; std::vector active_deps_; std::vector active_outputs_; std::unordered_map node_map_; size_t bytes_in_graph_{0}; bool is_graph_updatable_{true}; int max_ops_per_graph_; int max_mb_per_graph_; }; class Device { public: explicit Device(int device); ~Device(); Device(Device&&) 
= default; Device(const Device&) = delete; Device& operator=(const Device&) = delete; // Make this device the current cuda device, this method is thread-safe. void make_current(); CommandEncoder& get_command_encoder(Stream s); cublasLtHandle_t get_cublaslt_handle(); cudnnHandle_t get_cudnn_handle(); int cuda_device() const { return device_; } int compute_capability_major() const { return compute_capability_major_; } int compute_capability_minor() const { return compute_capability_minor_; } bool concurrent_managed_access() const { return concurrent_managed_access_ == 1; } bool host_native_atomic() const { return host_native_atomic_ == 1; } bool managed_memory() const { return managed_memory_ == 1; } bool memory_pools() const { return memory_pools_ == 1; } private: int device_; int compute_capability_major_; int compute_capability_minor_; int concurrent_managed_access_; int host_native_atomic_; int managed_memory_; int memory_pools_; std::string device_name_; cublasLtHandle_t cublaslt_handle_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; std::unordered_map encoders_; }; Device& device(int cuda_device); Device& device(mlx::core::Device d); CommandEncoder& get_command_encoder(Stream s); } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/device_info.cpp ================================================ // Copyright © 2026 Apple Inc. 
#include "mlx/backend/gpu/device_info.h"
#include "mlx/backend/cuda/cuda.h"

#include
#include
#include
#include
#include
#include

namespace mlx::core {

namespace {

// NVML dynamic loading for accurate memory reporting
// (cudaMemGetInfo only sees current process).
//
// Only the few NVML types and entry points used in this file are declared
// here; the library itself is opened at runtime with dlopen() in nvml_init(),
// so there is no link-time dependency on NVML.
typedef int nvmlReturn_t;
typedef struct nvmlDevice_st* nvmlDevice_t;

// Out-parameter of nvmlDeviceGetMemoryInfo. A pointer to this struct is handed
// straight to the library, so the field order must match NVML's own
// nvmlMemory_t. Units are presumably bytes (they are reported alongside
// cudaMemGetInfo's values below) — confirm against the NVML headers.
struct nvmlMemory_t {
  unsigned long long total;
  unsigned long long free;
  unsigned long long used;
};

// Library handle plus the resolved NVML function pointers. Every member stays
// null until nvml_init() fills it in.
struct NVMLState {
  void* handle = nullptr;
  nvmlReturn_t (*nvmlInit_v2)() = nullptr;
  nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char*, nvmlDevice_t*) =
      nullptr;
  nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t*) =
      nullptr;
};

// Opens the NVML shared library, resolves the three entry points declared in
// NVMLState and calls nvmlInit_v2().
//
// Returns true only if the library was found, every symbol resolved and
// nvmlInit_v2() returned 0. On false the caller (device_info_impl) falls back
// to cudaMemGetInfo.
// NOTE(review): when dlopen succeeds but a later step fails, the handle stays
// in nvml.handle and is never dlclose'd.
bool nvml_init(NVMLState& nvml) {
#ifdef _WIN32
  // Try the bare DLL name first, then an absolute install path.
  nvml.handle = dlopen("nvml.dll", RTLD_LAZY);
  if (!nvml.handle) {
    nvml.handle = dlopen(
        "C:\\Program Files\\NVIDIA Corporation\\NVSMI\\nvml.dll", RTLD_LAZY);
  }
#else
  nvml.handle = dlopen("libnvidia-ml.so.1", RTLD_LAZY);
#endif
  if (!nvml.handle)
    return false;
  nvml.nvmlInit_v2 =
      (decltype(nvml.nvmlInit_v2))dlsym(nvml.handle, "nvmlInit_v2");
  nvml.nvmlDeviceGetHandleByUUID =
      (decltype(nvml.nvmlDeviceGetHandleByUUID))dlsym(
          nvml.handle, "nvmlDeviceGetHandleByUUID");
  nvml.nvmlDeviceGetMemoryInfo =
      (decltype(nvml.nvmlDeviceGetMemoryInfo))dlsym(
          nvml.handle, "nvmlDeviceGetMemoryInfo");
  if (!nvml.nvmlInit_v2 || !nvml.nvmlDeviceGetHandleByUUID ||
      !nvml.nvmlDeviceGetMemoryInfo) {
    return false;
  }
  // NVML return codes are compared against 0 (success).
  return nvml.nvmlInit_v2() == 0;
}

// Looks up the device whose UUID string is `uuid` (as produced by format_uuid)
// and writes its free and total memory to the out-parameters.
//
// Returns false, leaving *free / *total untouched, if no library handle is
// loaded or either NVML call fails. The function pointers themselves are not
// null-checked, so this must only be used after nvml_init() returned true.
bool nvml_get_memory(
    NVMLState& nvml,
    const char* uuid,
    size_t* free,
    size_t* total) {
  if (!nvml.handle)
    return false;
  nvmlDevice_t device;
  if (nvml.nvmlDeviceGetHandleByUUID(uuid, &device) != 0)
    return false;
  nvmlMemory_t mem;
  if (nvml.nvmlDeviceGetMemoryInfo(device, &mem) != 0)
    return false;
  *free = mem.free;
  *total = mem.total;
  return true;
}

// Formats a CUDA device UUID as "GPU-" followed by the 16 bytes in lowercase
// hex, grouped 8-4-4-4-12 digits. This is the string passed to
// nvmlDeviceGetHandleByUUID in device_info_impl.
std::string format_uuid(const cudaUUID_t& uuid) {
  char buf[64];
  // Each byte is cast to unsigned char so that, in case `bytes` is a signed
  // char array, negative values are not sign-extended by %02x.
  snprintf(
      buf,
      sizeof(buf),
      "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
      (unsigned char)uuid.bytes[0],
      (unsigned char)uuid.bytes[1],
      (unsigned char)uuid.bytes[2],
      (unsigned char)uuid.bytes[3],
      (unsigned char)uuid.bytes[4],
      (unsigned char)uuid.bytes[5],
      (unsigned char)uuid.bytes[6],
      (unsigned char)uuid.bytes[7],
      (unsigned char)uuid.bytes[8],
      (unsigned char)uuid.bytes[9],
      (unsigned char)uuid.bytes[10],
      (unsigned char)uuid.bytes[11],
      (unsigned char)uuid.bytes[12],
      (unsigned char)uuid.bytes[13],
      (unsigned char)uuid.bytes[14],
      (unsigned char)uuid.bytes[15]);
  return buf;
}

// Returns the property map for CUDA device `device_index`.
//
// Static properties are queried once for every device and cached:
// "device_name", "uuid", "architecture", "pci_bus_id" and
// "compute_capability_major" / "compute_capability_minor".
// On every call the cached map is copied into a thread_local map and
// "free_memory" / "total_memory" are refreshed. NVML (system-wide numbers) is
// preferred, with cudaMemGetInfo (current process only) as the fallback.
//
// An out-of-range index yields a reference to a static empty map.
// NOTE(review): for a valid index the returned reference aliases the
// thread_local copy, so the next call on the same thread overwrites contents a
// caller may still be holding.
// NOTE(review): the return codes of cudaGetDeviceCount,
// cudaGetDeviceProperties and cudaMemGetInfo are not checked.
const std::unordered_map>& device_info_impl(int device_index) {
  // Static cache of device properties including UUID (needed for NVML lookup)
  static auto all_devices = []() {
    // Get device count
    int count = 0;
    cudaGetDeviceCount(&count);

    // Collect info for all devices
    struct DeviceInfo {
      std::unordered_map> info;
      std::string uuid;
    };
    std::vector devices;

    for (int i = 0; i < count; ++i) {
      cudaDeviceProp prop;
      cudaGetDeviceProperties(&prop, i);

      DeviceInfo dev;
      dev.info["device_name"] = std::string(prop.name);
      dev.uuid = format_uuid(prop.uuid);
      dev.info["uuid"] = dev.uuid;

      // Architecture string (e.g., "sm_89")
      char arch[16];
      snprintf(arch, sizeof(arch), "sm_%d%d", prop.major, prop.minor);
      dev.info["architecture"] = std::string(arch);

      // PCI bus ID (domain:bus:device.function); the function number is
      // always written as 0.
      char pci_id[32];
      snprintf(
          pci_id,
          sizeof(pci_id),
          "%04x:%02x:%02x.0",
          prop.pciDomainID,
          prop.pciBusID,
          prop.pciDeviceID);
      dev.info["pci_bus_id"] = std::string(pci_id);

      // Compute capability as size_t (to match Metal's variant type)
      dev.info["compute_capability_major"] = static_cast(prop.major);
      dev.info["compute_capability_minor"] = static_cast(prop.minor);

      devices.push_back(std::move(dev));
    }
    return devices;
  }();

  // Initialize NVML once for fresh memory reads
  static NVMLState nvml;
  static bool nvml_initialized = nvml_init(nvml);

  if (device_index < 0 ||
      device_index >= static_cast(all_devices.size())) {
    static auto empty = std::unordered_map>();
    return empty;
  }

  // Return a copy with fresh memory info.
  // Using thread_local to avoid locks while keeping free_memory fresh.
  thread_local auto device_info_copy = std::unordered_map>();
  device_info_copy = all_devices[device_index].info;

  // Get fresh memory info - try NVML first (system-wide), fallback to
  // cudaMemGetInfo (process-level)
  size_t free_mem, total_mem;
  if (nvml_initialized &&
      nvml_get_memory(
          nvml, all_devices[device_index].uuid.c_str(), &free_mem, &total_mem)) {
    // NVML succeeded - use system-wide memory
  } else {
    // Fallback to cudaMemGetInfo (process-scoped). It reports on the current
    // device, so switch to `device_index` temporarily and restore the
    // caller's device afterwards.
    int prev_device;
    cudaGetDevice(&prev_device);
    cudaSetDevice(device_index);
    cudaMemGetInfo(&free_mem, &total_mem);
    cudaSetDevice(prev_device);
  }
  device_info_copy["free_memory"] = free_mem;
  device_info_copy["total_memory"] = total_mem;
  return device_info_copy;
}

} // anonymous namespace

namespace gpu {

// The CUDA backend is compiled in, so a GPU backend is reported as available.
bool is_available() {
  return true;
}

// Number of CUDA devices visible to this process (0 if the query fails).
int device_count() {
  int count = 0;
  cudaGetDeviceCount(&count);
  return count;
}

// See device_info_impl for the returned keys and the lifetime of the
// reference.
const std::unordered_map>&
device_info(int device_index) {
  return device_info_impl(device_index);
}

} // namespace gpu

namespace cu {

bool is_available() {
  return true;
}

} // namespace cu

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/distributed.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"

#include

namespace mlx::core::distributed {

// All-reduce of inputs[0] into outputs[0] across group(). reduce_type_ selects
// sum, max or min; any other reduction type throws std::runtime_error.
void AllReduce::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Chooses the (input, output) buffers for the collective:
  //  - non-row-contiguous input: copy it into `out` first, then reduce in
  //    place on `out`;
  //  - donatable input: `out` takes over the input's buffer (in place);
  //  - otherwise: allocate a fresh buffer for `out`.
  auto set_input_output =
      [&](const array& in, array& out) -> std::pair {
    if (!in.flags().row_contiguous) {
      copy_gpu(in, out, CopyType::General, s);
      return {out, out};
    } else if (in.is_donatable()) {
      out.copy_shared_buffer(in);
      return {in, out};
    } else {
      out.set_data(cu::malloc_async(out.nbytes(), encoder));
      return {in, out};
    }
  };

  auto [input, output] = set_input_output(inputs[0], outputs[0]);

  encoder.set_input_array(input);
  encoder.set_output_array(output);

  // NOTE(review): presumably lets the encoder pick up the work issued by the
  // collective below (which does not go through add_kernel_node); confirm
  // against CommandEncoder::capture_context.
  auto capture = encoder.capture_context();
  switch (reduce_type_) {
    case Sum:
      distributed::detail::all_sum(group(), input, output, s);
      break;
    case Max:
      distributed::detail::all_max(group(), input, output, s);
      break;
    case Min:
      distributed::detail::all_min(group(), input, output, s);
      break;
    default:
      throw std::runtime_error(
          "Only all reduce sum, max, and min are supported.");
  }
}

// All-gather of inputs[0] from every member of group() into outputs[0], which
// always receives a freshly allocated buffer.
void AllGather::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Returns `x` unchanged if it is row-contiguous; otherwise returns a
  // contiguous copy, registered as a temporary so it outlives the collective.
  // (Identical helper is duplicated in ReduceScatter::eval_gpu.)
  auto ensure_contiguous = [&s, &encoder](const array& x) {
    if (x.flags().row_contiguous) {
      return x;
    } else {
      array x_copy = contiguous_copy_gpu(x, s);
      encoder.add_temporary(x_copy);
      return x_copy;
    }
  };

  auto input = ensure_contiguous(inputs[0]);
  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));

  encoder.set_input_array(input);
  encoder.set_output_array(outputs[0]);

  auto capture = encoder.capture_context();
  distributed::detail::all_gather(group(), input, outputs[0], s);
}

// Reduce-scatter of inputs[0] across group() into a freshly allocated
// outputs[0]. Only the Sum reduction is supported; anything else throws.
void ReduceScatter::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Same as in AllGather::eval_gpu: pass through row-contiguous input,
  // otherwise make a contiguous temporary copy.
  auto ensure_contiguous = [&s, &encoder](const array& x) {
    if (x.flags().row_contiguous) {
      return x;
    } else {
      array x_copy = contiguous_copy_gpu(x, s);
      encoder.add_temporary(x_copy);
      return x_copy;
    }
  };

  auto input = ensure_contiguous(inputs[0]);
  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));

  encoder.set_input_array(input);
  encoder.set_output_array(outputs[0]);

  auto capture = encoder.capture_context();

  switch (reduce_type_) {
    case Sum:
      distributed::detail::sum_scatter(group(), input, outputs[0], s);
      break;
    default:
      throw std::runtime_error("Only sum scatter is supported. ");
  }
}
} // namespace mlx::core::distributed
================================================
FILE: mlx/backend/cuda/eval.cpp
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"

#include

namespace mlx::core::gpu {

// Prepares the CUDA backend for a new stream `s`. The order of the calls
// below matters because it fixes the construction order, and therefore the
// destruction order, of the backend's static objects.
void new_stream(Stream s) {
  // Force initialization of CUDA so that the CUDA runtime gets destroyed last.
  cudaFree(nullptr);
  // Make sure the CUDA event pool gets destroyed after device and stream.
  cu::CudaEvent::init_pool();
  // Ensure the static stream objects get created.
  cu::get_command_encoder(s);
}

// Evaluates `arr`'s primitive on its GPU stream. Input (and sibling) buffers
// are registered with the command encoder so they stay alive while the work
// runs. If the encoder reports it needs a commit, the scheduler is told about
// a new task and is notified again from the encoder's completion handler.
void eval(array& arr) {
  nvtx3::scoped_range r("gpu::eval");
  // Ensure CUDA context is active on this thread. Required when MLX is called
  // from threads that have not yet established a CUDA context (e.g. thread
  // pools, language runtimes that migrate work across OS threads).
  cu::device(arr.primitive().stream().device).make_current();

  auto outputs = arr.outputs();
  {
    // If the array is a tracer hold a reference
    // to its inputs so they don't get donated
    std::vector inputs;
    if (arr.is_tracer()) {
      inputs = arr.inputs();
    }
    arr.primitive().eval_gpu(arr.inputs(), outputs);
  }

  auto& stream = arr.primitive().stream();
  auto& encoder = cu::get_command_encoder(stream);
  // Keep used buffers alive until kernel finishes running.
  for (auto& in : arr.inputs()) {
    // Except for the donated one (an input sharing its data with `arr`).
    if (in.data_shared_ptr() != arr.data_shared_ptr()) {
      encoder.add_temporary(in);
    }
  }
  for (auto& s : arr.siblings()) {
    encoder.add_temporary(s);
  }

  if (encoder.needs_commit()) {
    scheduler::notify_new_task(stream);
    encoder.add_completed_handler(
        [stream]() { scheduler::notify_task_completion(stream); });
    encoder.commit();
  }
}

// Commits any work still pending in the command encoder of stream `s`.
void finalize(Stream s) {
  nvtx3::scoped_range r("gpu::finalize");
  cu::get_command_encoder(s).commit();
}

// Blocks until the kernels and completion handlers of stream `s` are finished
// (see CommandEncoder::synchronize).
void synchronize(Stream s) {
  nvtx3::scoped_range r("gpu::synchronize");
  cu::get_command_encoder(s).synchronize();
}

} // namespace mlx::core::gpu
================================================
FILE: mlx/backend/cuda/event.cu
================================================
// Copyright © 2024 Apple Inc.

#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/gpu/device_info.h"
#include "mlx/event.h"
#include "mlx/scheduler.h"

#include
#include
#include

namespace mlx::core {

namespace cu {

///////////////////////////////////////////////////////////////////////////////
// CudaEvent implementations
///////////////////////////////////////////////////////////////////////////////

namespace {

// Manage cached cudaEvent_t objects.
// Events are cached per (device, flags) and are only reused on the thread
// that constructed the pool.
class CudaEventPool {
 public:
  // Returns a cached event for (d, flags) when called on the creation thread
  // and one is available; otherwise creates a brand-new event.
  CudaEventHandle create(Device& d, int flags) {
    if (!on_creation_thread()) {
      return CudaEventHandle(d, flags);
    }
    auto& cache = cache_for(d, flags);
    if (cache.empty()) {
      return CudaEventHandle(d, flags);
    } else {
      CudaEventHandle ret = std::move(cache.back());
      cache.pop_back();
      return ret;
    }
  }

  // Puts `event` back into the cache for its (device, flags).
  void release(CudaEventHandle event) {
    if (!on_creation_thread()) {
      // Event will be destroyed directly instead of getting moved to cache.
      return;
    }
    cache_for(event.device, event.flags).push_back(std::move(event));
  }

 private:
  // Cache bucket for (d, flags); operator[] creates it on first access.
  std::vector& cache_for(Device& d, int flags) {
    return cache_[d.cuda_device()][flags];
  }

  bool on_creation_thread() {
    return std::this_thread::get_id() == thread_id_;
  }

  // The CudaEvent may be created and destroyed on different threads (for
  // example when waiting on GPU work in CPU stream), we don't want to make
  // the cache thread-safe as it adds overhead, so we just skip cache when
  // using events in worker threads.
  std::thread::id thread_id_{std::this_thread::get_id()};
  // {device: {flags: [events]}}
  std::map>> cache_;
};

// Process-wide event pool, constructed on first use.
CudaEventPool& cuda_event_pool() {
  static CudaEventPool pool;
  return pool;
}

} // namespace

// Creates a cudaEvent_t with `flags` on device `d` (made current first).
CudaEventHandle::CudaEventHandle(Device& d, int flags)
    : device(d), flags(flags) {
  device.make_current();
  CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
  assert(handle_ != nullptr);
}

// Takes an event from the pool (or a new one) and hands it back on
// destruction.
CudaEvent::CudaEvent(Device& d, int flags)
    : event_(cuda_event_pool().create(d, flags)) {}

CudaEvent::~CudaEvent() {
  cuda_event_pool().release(std::move(event_));
}

// Blocks the calling host thread until the event has completed.
void CudaEvent::wait() {
  nvtx3::scoped_range r("cu::CudaEvent::wait");
  event_.device.make_current();
  cudaEventSynchronize(event_);
}

// Makes `stream` wait for the event without blocking the host.
void CudaEvent::wait(cudaStream_t stream) {
  event_.device.make_current();
  cudaStreamWaitEvent(stream, event_);
}

// Records the event on `stream`.
void CudaEvent::record(cudaStream_t stream) {
  event_.device.make_current();
  cudaEventRecord(event_, stream);
}

bool CudaEvent::completed() const {
  // Note: cudaEventQuery can be safely called from any device.
  return cudaEventQuery(event_) == cudaSuccess;
}

// static
// Touching the pool forces construction of its function-local static.
void CudaEvent::init_pool() {
  cuda_event_pool();
}

// Wraps CudaEvent with a few features:
// 1. The class can be copied (copies share the underlying CudaEvent).
// 2. Make wait() work with CPU streams (record() on a CPU stream throws).
// 3. Add checks for waiting on un-recorded event.
class CopyableCudaEvent {
 public:
  explicit CopyableCudaEvent(Device& d)
      : event_(
            std::make_shared(
                d,
                cudaEventDisableTiming | cudaEventBlockingSync)) {}

  // Host-blocking wait.
  void wait() {
    event_->wait();
  }

  // Waits on stream `s`. For a CPU stream the host-side wait is enqueued on
  // that stream's scheduler. For a GPU stream the encoder's pending work is
  // committed first, then the stream is made to wait.
  // NOTE(review): the CPU-stream lambda captures a copy of *this, so
  // check_recorded() sees recorded_ as it was at enqueue time.
  void wait(Stream s) {
    if (s.device == mlx::core::Device::cpu) {
      scheduler::enqueue(s, [*this]() mutable {
        check_recorded();
        event_->wait();
      });
    } else {
      check_recorded();
      auto& encoder = cu::get_command_encoder(s);
      encoder.commit();
      event_->wait(encoder.stream());
    }
  }

  // Records the event on GPU stream `s` after committing its pending work.
  // NOTE(review): the error text for CPU streams says "wait" although this is
  // the record path.
  void record(Stream s) {
    if (s.device == mlx::core::Device::cpu) {
      throw std::runtime_error("CudaEvent can not wait on CPU stream.");
    } else {
      auto& encoder = cu::get_command_encoder(s);
      encoder.commit();
      event_->record(encoder.stream());
      recorded_ = true;
    }
  }

  // True only once recorded and the recorded work has completed.
  bool is_signaled() const {
    return recorded_ && event_->completed();
  }

 private:
  void check_recorded() const {
    if (!recorded_) {
      throw std::runtime_error(
          "Should not wait on a CudaEvent before recording.");
    }
  }

  std::shared_ptr event_;
  bool recorded_{false};
};

///////////////////////////////////////////////////////////////////////////////
// AtomicEvent implementations
///////////////////////////////////////////////////////////////////////////////

// Blocks (host or device) until the counter at `ptr` is at least `value`.
// ac.wait(current) sleeps until the stored value differs from `current`.
__host__ __device__ void event_wait(uint32_t* ptr, uint32_t value) {
  cuda::atomic_ref ac(*ptr);
  uint32_t current;
  while ((current = ac.load()) < value) {
    ac.wait(current);
  }
}

// Stores `value` into the counter at `ptr` and wakes all waiters.
__host__ __device__ void event_signal(uint32_t* ptr, uint32_t value) {
  cuda::atomic_ref ac(*ptr);
  ac.store(value);
  ac.notify_all();
}

// Launched as <<<1, 1>>>: makes a GPU stream block until the counter reaches
// `value`.
__global__ void event_wait_kernel(uint32_t* ptr, uint32_t value) {
  event_wait(ptr, value);
}

// Launched as <<<1, 1>>>: system-wide fences around the store so that writes
// issued before the signal are ordered with respect to it.
__global__ void event_signal_kernel(uint32_t* ptr, uint32_t value) {
  __threadfence_system();
  event_signal(ptr, value);
  __threadfence_system();
}

// Returns (concurrent_managed_access, host_native_atomic), each true only if
// every visible device supports it. Computed once and cached.
auto check_gpu_coherency() {
  static auto coherency = []() {
    int device_count = gpu::device_count();
    bool concurrent_managed_access = true;
    bool host_native_atomic = true;
    for (int i = 0; i < device_count; ++i) {
      auto& d = cu::device(i);
      concurrent_managed_access &= d.concurrent_managed_access();
      host_native_atomic &= d.host_native_atomic();
    }
    return std::make_tuple(concurrent_managed_access, host_native_atomic);
  }();
  return coherency;
}

AtomicEvent::AtomicEvent(Device& d) {
  void* buf;
  cudaError_t (*cuda_free)(void*);
  // There are 3 kinds of systems we are implementing for:
  // 1. concurrentManagedAccess == true
  //    => use cuda::atomic_ref on managed memory
  // 2. hostNativeAtomicSupported == true
  //    => use cuda::atomic_ref on pinned host memory
  // 3. no hardware cpu/gpu coherency
  //    => use cuda::atomic_ref on device memory
  // The choice comes from check_gpu_coherency(), so a capability must be
  // present on every visible device, not just `d`.
  d.make_current();
  auto [concurrent_managed_access, host_native_atomic] = check_gpu_coherency();
  if (concurrent_managed_access) {
    CHECK_CUDA_ERROR(cudaMallocManaged(&buf, sizeof(uint32_t)));
    cuda_free = cudaFree;
    coherent_ = true;
  } else if (host_native_atomic) {
    CHECK_CUDA_ERROR(cudaMallocHost(&buf, sizeof(uint32_t)));
    cuda_free = cudaFreeHost;
    coherent_ = true;
  } else {
    CHECK_CUDA_ERROR(cudaMalloc(&buf, sizeof(uint32_t)));
    cuda_free = cudaFree;
    coherent_ = false;
  }
  // The deleter releases with the free function matching the allocator above.
  buf_ = std::shared_ptr(
      buf, [cuda_free](void* buf) { CHECK_CUDA_ERROR(cuda_free(buf)); });
  // Host-accessible memory is zeroed directly; device memory via cudaMemset.
  if (coherent_) {
    *ptr() = 0;
  } else {
    CHECK_CUDA_ERROR(cudaMemset(buf, 0, sizeof(uint32_t)));
  }
}

// Host-blocking wait until the counter reaches `value`. The non-coherent path
// polls value() (a device-to-host cudaMemcpy) and yields between polls.
void AtomicEvent::wait(uint32_t value) {
  nvtx3::scoped_range r("cu::AtomicEvent::wait");
  if (coherent_) {
    event_wait(ptr(), value);
  } else {
    while (!is_signaled(value)) {
      std::this_thread::yield();
    }
  }
}

// Makes `stream` wait by launching a single-thread wait kernel on it.
void AtomicEvent::wait(cudaStream_t stream, uint32_t value) {
  event_wait_kernel<<<1, 1, 0, stream>>>(ptr(), value);
}

// Waits on stream `s`. CPU stream: enqueue a host-side wait. GPU stream:
// commit pending work, launch the wait kernel, and capture buf_ in a
// completion handler so the counter outlives the kernel.
void AtomicEvent::wait(Stream s, uint32_t value) {
  nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
  if (s.device == mlx::core::Device::cpu) {
    scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
  } else {
    auto& encoder = get_command_encoder(s);
    encoder.commit();
    wait(encoder.stream(), value);
    encoder.add_completed_handler([buf = buf_]() {});
  }
}

// Signals from the host. The non-coherent path launches the signal kernel on
// signal_stream().
// NOTE(review): that kernel launch is not synchronized here, so the store may
// land after this function returns.
void AtomicEvent::signal(uint32_t value) {
  nvtx3::scoped_range r("cu::AtomicEvent::signal");
  if (coherent_) {
    event_signal(ptr(), value);
  } else {
    signal(signal_stream(), value);
  }
}

// Enqueues a single-thread signal kernel on `stream`.
void AtomicEvent::signal(cudaStream_t stream, uint32_t value) {
  event_signal_kernel<<<1, 1, 0, stream>>>(ptr(), value);
}

void AtomicEvent::signal(Stream s, uint32_t value) {
  nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
  if (s.device == mlx::core::Device::cpu) {
    // Signal through a GPU stream so the atomic is updated in GPU - updating
    // the atomic in CPU sometimes does not get GPU notified.
    scheduler::enqueue(
        s, [*this, value]() mutable { signal(signal_stream(), value); });
  } else {
    auto& encoder = get_command_encoder(s);
    encoder.commit();
    signal(encoder.stream(), value);
    // Keep the counter buffer alive until the signal kernel has run.
    encoder.add_completed_handler([buf = buf_]() {});
  }
}

bool AtomicEvent::is_signaled(uint32_t val) const {
  return value() >= val;
}

// Current counter value: an atomic load when host-accessible, otherwise a
// blocking device-to-host copy.
uint32_t AtomicEvent::value() const {
  nvtx3::scoped_range r("cu::AtomicEvent::value");
  if (coherent_) {
    cuda::atomic_ref ac(*ptr());
    return ac.load();
  } else {
    uint32_t val;
    CHECK_CUDA_ERROR(
        cudaMemcpy(&val, ptr(), sizeof(uint32_t), cudaMemcpyDeviceToHost));
    return val;
  }
}

// Dedicated stream used for host-initiated signals.
// NOTE(review): always bound to device(0), regardless of the device this event
// was created on.
const CudaStream& AtomicEvent::signal_stream() {
  static CudaStream stream(device(0));
  return stream;
}

} // namespace cu

///////////////////////////////////////////////////////////////////////////////
// Event implementations
///////////////////////////////////////////////////////////////////////////////

namespace {

struct EventImpl {
  // CudaEvent is preferred when possible because it is fast, however we have
  // to fallback to AtomicEvent in following cases:
  // 1. the event is used to wait/signal a cpu stream;
  // 2. signal value other than 1 has been specified.
  std::unique_ptr cuda;
  std::unique_ptr atomic;

  bool is_created() const {
    return cuda || atomic;
  }

  // Lazily creates the backing event the first time it is needed; later calls
  // are no-ops, so the kind is fixed by the first (stream, value) seen.
  void ensure_created(Stream s, uint64_t signal_value) {
    if (is_created()) {
      return;
    }
    auto& d = cu::device(s.device);
    if (s.device == mlx::core::Device::cpu || signal_value > 1) {
      nvtx3::mark("Using slow AtomicEvent");
      atomic = std::make_unique(d);
    } else {
      cuda = std::make_unique(d);
    }
  }
};

} // namespace

Event::Event(Stream s) : stream_(s) {
  event_ = std::shared_ptr(
      new EventImpl(), [](void* ptr) { delete static_cast(ptr); });
}

// Host-blocking wait for value(). Any pending CUDA error is surfaced
// afterwards via cudaPeekAtLastError.
void Event::wait() {
  auto* event = static_cast(event_.get());
  assert(event->is_created());
  if (event->cuda) {
    assert(value() == 1);
    event->cuda->wait();
  } else {
    event->atomic->wait(value());
  }
  CHECK_CUDA_ERROR(cudaPeekAtLastError());
}

// Makes stream `s` wait for value(); the event must already have been
// signaled (created) at least once.
void Event::wait(Stream s) {
  auto* event = static_cast(event_.get());
  assert(event->is_created());
  if (event->cuda) {
    assert(value() == 1);
    event->cuda->wait(s);
  } else {
    event->atomic->wait(s, value());
  }
}

// Signals value() from stream `s`, creating the backing event if needed.
void Event::signal(Stream s) {
  auto* event = static_cast(event_.get());
  event->ensure_created(s, value());
  if (event->cuda) {
    assert(value() == 1);
    event->cuda->record(s);
  } else {
    event->atomic->signal(s, value());
  }
}

// False if the event was never signaled; otherwise queries the backing event.
bool Event::is_signaled() const {
  auto* event = static_cast(event_.get());
  if (!event->is_created()) {
    return false;
  }
  if (event->cuda) {
    assert(value() == 1);
    return event->cuda->is_signaled();
  } else {
    return event->atomic->is_signaled(value());
  }
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/event.h
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/stream.h"

#include
#include
#include

namespace mlx::core::cu {

class Device;

// RAII-managed move-only wrapper of cudaEvent_t.
struct CudaEventHandle : public CudaHandle {
  CudaEventHandle(Device& d, int flags);
  // Device and flags the event was created with; used as the pool cache key.
  Device& device;
  int flags;
};

// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
// on GPU stream in CPU stream, but can not wait on CPU stream.
class CudaEvent {
 public:
  CudaEvent(Device& d, int flags);
  ~CudaEvent();

  CudaEvent(CudaEvent&&) = default;
  CudaEvent& operator=(CudaEvent&&) = default;

  CudaEvent(const CudaEvent&) = delete;
  CudaEvent& operator=(const CudaEvent&) = delete;

  // Host-blocking wait.
  void wait();
  // Make `stream` wait for the event.
  void wait(cudaStream_t stream);
  // Record the event on `stream`.
  void record(cudaStream_t stream);

  // Return whether the recorded kernels have completed. Note that this method
  // returns true if record() has not been called.
  bool completed() const;

  // Internal: make sure event pool is initialized.
  static void init_pool();

 private:
  CudaEventHandle event_;
};

// Event that can synchronize between CPU and GPU. It is much slower than
// CudaEvent so the latter should always be preferred when possible.
// It is a shared uint32_t counter: signal(v) stores v, wait(v) blocks until
// the counter is >= v. Copies share the same counter buffer.
class AtomicEvent {
 public:
  AtomicEvent(Device& d);

  void wait(uint32_t value);
  void wait(cudaStream_t stream, uint32_t value);
  void wait(Stream s, uint32_t value);
  void signal(uint32_t value);
  void signal(cudaStream_t stream, uint32_t value);
  void signal(Stream s, uint32_t value);
  bool is_signaled(uint32_t value) const;
  uint32_t value() const;

 private:
  const CudaStream& signal_stream();

  uint32_t* ptr() const {
    return static_cast(buf_.get());
  }

  // True when the counter lives in host-accessible (managed or pinned)
  // memory; false when it is plain device memory.
  bool coherent_;
  std::shared_ptr buf_;
};

} // namespace mlx::core::cu
================================================
FILE: mlx/backend/cuda/fence.cpp
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/fence.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"

namespace mlx::core {

// Backing state of a Fence: a counter that update() only ever increments,
// paired with an AtomicEvent that is signaled with each new counter value.
struct FenceImpl {
  uint32_t count;
  cu::AtomicEvent event;
};

// Creates the AtomicEvent on the CUDA device of stream `s`; count starts at 0.
Fence::Fence(Stream s) {
  fence_ = std::shared_ptr(
      new FenceImpl{0, cu::device(s.device)},
      [](void* ptr) { delete static_cast(ptr); });
}

// Waits until the event reaches the current counter value.
// NOTE(review): the stream argument `s` is unused. This calls the host-blocking
// AtomicEvent::wait(uint32_t) rather than the AtomicEvent::wait(Stream,
// uint32_t) overload, so the calling thread blocks; confirm this is intended.
void Fence::wait(Stream s, const array&) {
  auto* fence = static_cast(fence_.get());
  fence->event.wait(fence->count);
}

// Bumps the counter and signals the new value from stream `s`. For a
// cross-device update the array's buffer is first moved to unified memory
// (after committing pending work on `s`), unless cbuf.device is already -1.
// NOTE(review): -1 presumably marks a buffer that is already unified/managed;
// confirm against the CUDA allocator.
void Fence::update(Stream s, const array& a, bool cross_device) {
  auto* fence = static_cast(fence_.get());
  if (cross_device) {
    // Move to managed memory if there is a device switch
    auto& cbuf =
        *static_cast(const_cast(a).buffer().ptr());
    if (cbuf.device != -1) {
      auto& encoder = cu::get_command_encoder(s);
      encoder.commit();
      cu::allocator().move_to_unified_memory(cbuf, encoder.stream());
    }
  }
  fence->count++;
  fence->event.signal(s, fence->count);
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/fft.cu
================================================
// Copyright © 2025 Apple Inc.
#include #include #include #include #include #include #include #include #include #include #include "mlx/backend/common/utils.h" #include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/complex.cuh" #include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/cuda/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/primitives.h" namespace mlx::core { namespace cu { namespace cg = cooperative_groups; template __global__ void scale_fft_output(T* out, T scale, size_t size) { auto index = cg::this_grid().thread_rank(); if (index < size) { out[index] *= scale; } } } // namespace cu namespace { void check_cufft_error(const char* name, cufftResult err) { if (err != CUFFT_SUCCESS) { throw std::runtime_error( std::string(name) + " failed with code: " + std::to_string(static_cast(err)) + "."); } } #define CHECK_CUFFT_ERROR(cmd) check_cufft_error(#cmd, (cmd)) enum class FFTTransformType : uint8_t { C2C = 0, R2C = 1, C2R = 2, }; struct FFTPlanKey { int device_id; FFTTransformType transform_type; int64_t n; int64_t batch; }; struct CuFFTPlan { explicit CuFFTPlan(int device_id, cufftHandle handle, size_t workspace_size) : device_id(device_id), handle(handle), workspace_size(workspace_size) {} ~CuFFTPlan() { if (handle != 0) { try { cu::device(device_id).make_current(); cufftDestroy(handle); } catch (...) 
{ } } } int device_id; cufftHandle handle; size_t workspace_size; }; struct OrderedArray { array arr; std::vector order; }; auto& fft_plan_cache() { static LRUBytesKeyCache> cache( "MLX_CUDA_FFT_CACHE_SIZE", /* default_capacity */ 128); return cache; } FFTPlanKey make_plan_key( int device_id, FFTTransformType transform_type, int64_t n, int64_t batch) { FFTPlanKey key{}; key.device_id = device_id; key.transform_type = transform_type; key.n = n; key.batch = batch; return key; } cudaDataType_t input_type(FFTTransformType transform_type) { switch (transform_type) { case FFTTransformType::C2C: case FFTTransformType::C2R: return CUDA_C_32F; case FFTTransformType::R2C: return CUDA_R_32F; } throw std::runtime_error("[FFT] Unsupported cuFFT input transform type."); } cudaDataType_t output_type(FFTTransformType transform_type) { switch (transform_type) { case FFTTransformType::C2C: case FFTTransformType::R2C: return CUDA_C_32F; case FFTTransformType::C2R: return CUDA_R_32F; } throw std::runtime_error("[FFT] Unsupported cuFFT output transform type."); } cudaDataType_t execution_type(FFTTransformType transform_type) { switch (transform_type) { case FFTTransformType::C2C: return CUDA_C_32F; case FFTTransformType::R2C: return CUDA_R_32F; case FFTTransformType::C2R: return CUDA_C_32F; } throw std::runtime_error("[FFT] Unsupported cuFFT execution transform type."); } int64_t input_embed(FFTTransformType transform_type, int64_t n) { return transform_type == FFTTransformType::C2R ? (n / 2 + 1) : n; } int64_t output_embed(FFTTransformType transform_type, int64_t n) { return transform_type == FFTTransformType::R2C ? (n / 2 + 1) : n; } int exec_direction(FFTTransformType transform_type, bool inverse) { switch (transform_type) { case FFTTransformType::C2C: return inverse ? 
CUFFT_INVERSE : CUFFT_FORWARD; case FFTTransformType::R2C: return CUFFT_FORWARD; case FFTTransformType::C2R: return CUFFT_INVERSE; } throw std::runtime_error("[FFT] Unsupported cuFFT execution direction."); } std::shared_ptr get_fft_plan( cu::CommandEncoder& encoder, FFTTransformType transform_type, int64_t n, int64_t batch) { auto key = BytesKey{}; key.pod = make_plan_key(encoder.device().cuda_device(), transform_type, n, batch); auto& cache = fft_plan_cache(); if (auto entry = cache.find(key); entry != cache.end()) { return entry->second; } encoder.device().make_current(); cufftHandle handle = 0; size_t workspace_size = 0; try { CHECK_CUFFT_ERROR(cufftCreate(&handle)); CHECK_CUFFT_ERROR(cufftSetAutoAllocation(handle, 0)); CHECK_CUFFT_ERROR(cufftSetStream(handle, encoder.stream())); long long plan_n[1] = {n}; long long inembed[1] = {input_embed(transform_type, n)}; long long onembed[1] = {output_embed(transform_type, n)}; CHECK_CUFFT_ERROR(cufftXtMakePlanMany( handle, /* rank= */ 1, plan_n, inembed, /* istride= */ 1, /* idist= */ input_embed(transform_type, n), input_type(transform_type), onembed, /* ostride= */ 1, /* odist= */ output_embed(transform_type, n), output_type(transform_type), batch, &workspace_size, execution_type(transform_type))); } catch (...) 
{ if (handle != 0) { encoder.device().make_current(); cufftDestroy(handle); } throw; } auto plan = std::make_shared( encoder.device().cuda_device(), handle, workspace_size); return cache.emplace(key, plan).first->second; } std::vector make_identity_order(int ndim) { std::vector order(ndim); std::iota(order.begin(), order.end(), 0); return order; } std::vector move_axis_to_back_permutation(int ndim, int axis_pos) { std::vector perm; perm.reserve(ndim); for (int i = 0; i < ndim; ++i) { if (i != axis_pos) { perm.push_back(i); } } perm.push_back(axis_pos); return perm; } std::vector apply_permutation( const std::vector& values, const std::vector& perm) { std::vector out(perm.size()); for (int i = 0; i < perm.size(); ++i) { out[i] = values[perm[i]]; } return out; } int find_axis_position(const std::vector& order, int axis) { auto it = std::find(order.begin(), order.end(), axis); if (it == order.end()) { throw std::runtime_error("[FFT] Internal axis tracking mismatch."); } return static_cast(it - order.begin()); } OrderedArray prepare_input( const OrderedArray& current, int axis, bool allow_direct, cu::CommandEncoder& encoder, Stream s) { int axis_pos = find_axis_position(current.order, axis); bool axis_last = axis_pos == static_cast(current.order.size()) - 1; bool direct = allow_direct && axis_last && current.arr.flags().row_contiguous; if (direct) { return current; } array view = current.arr; std::vector order = current.order; if (!axis_last) { auto perm = move_axis_to_back_permutation(current.arr.ndim(), axis_pos); view = transpose_in_eval(current.arr, perm); order = apply_permutation(current.order, perm); } array packed = contiguous_copy_gpu(view, s); encoder.add_temporary(packed); return {std::move(packed), std::move(order)}; } void execute_fft( const array& in, array& out, FFTTransformType transform_type, bool inverse, cu::CommandEncoder& encoder) { if (!in.flags().row_contiguous || in.strides(-1) != 1) { throw std::runtime_error("[FFT] Expected packed 
row-contiguous FFT input."); } int64_t n = transform_type == FFTTransformType::C2R ? out.shape(-1) : in.shape(-1); int64_t batch = in.shape().empty() ? 1 : in.size() / in.shape(-1); auto plan = get_fft_plan(encoder, transform_type, n, batch); encoder.set_input_array(in); out.set_data(cu::malloc_async(out.nbytes(), encoder)); encoder.set_output_array(out); encoder.add_completed_handler([plan]() {}); encoder.device().make_current(); CHECK_CUFFT_ERROR(cufftSetStream(plan->handle, encoder.stream())); auto* workspace = allocate_workspace(encoder, plan->workspace_size); CHECK_CUFFT_ERROR(cufftSetWorkArea(plan->handle, workspace)); auto capture = encoder.capture_context(); CHECK_CUFFT_ERROR(cufftXtExec( plan->handle, gpu_ptr(in), gpu_ptr(out), exec_direction(transform_type, inverse))); } void restore_output_layout(const OrderedArray& current, array& out) { Strides out_strides(out.ndim()); for (int i = 0; i < current.order.size(); ++i) { out_strides[current.order[i]] = current.arr.strides(i); } auto [data_size, row_contiguous, col_contiguous] = check_contiguity(out.shape(), out_strides); bool contiguous = current.arr.flags().contiguous && data_size == current.arr.data_size(); out.copy_shared_buffer( current.arr, out_strides, {contiguous, row_contiguous, col_contiguous}, current.arr.data_size()); } void apply_inverse_scale( array& arr, const std::vector& axes, const array& out, cu::CommandEncoder& encoder) { if (axes.empty()) { return; } double scale = 1.0; for (auto axis : axes) { scale /= out.shape(axis); } size_t size = arr.data_size(); dim3 block_dims(256); dim3 grid_dims((size + block_dims.x - 1) / block_dims.x); encoder.set_input_array(arr); encoder.set_output_array(arr); if (arr.dtype() == float32) { float scale_f = static_cast(scale); encoder.add_kernel_node( cu::scale_fft_output, grid_dims, block_dims, gpu_ptr(arr), scale_f, size); } else if (arr.dtype() == complex64) { cu::complex64_t scale_f(static_cast(scale), 0.0f); encoder.add_kernel_node( 
cu::scale_fft_output, grid_dims, block_dims, gpu_ptr(arr), scale_f, size); } else { throw std::runtime_error("[FFT] Unsupported dtype for inverse scaling."); } } } // namespace void FFT::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("FFT::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); auto& in = inputs[0]; if (out.size() == 0) { return; } auto order = make_identity_order(in.ndim()); OrderedArray current{in, std::move(order)}; std::vector axis_sequence; axis_sequence.reserve(axes_.size()); if (inverse_) { for (auto axis : axes_) { axis_sequence.push_back(static_cast(axis)); } } else { for (int i = static_cast(axes_.size()) - 1; i >= 0; --i) { axis_sequence.push_back(static_cast(axes_[i])); } } int real_axis = axes_.empty() ? -1 : static_cast(axes_.back()); for (int i = 0; i < axis_sequence.size(); ++i) { int axis = axis_sequence[i]; bool step_real = real_ && axis == real_axis; auto transform_type = step_real ? (inverse_ ? FFTTransformType::C2R : FFTTransformType::R2C) : FFTTransformType::C2C; // cuFFT may overwrite the input buffer for C2R, so only use the direct // input when the transform is out-of-place from the library's perspective // or when the original input may be donated to the output. auto prepared = prepare_input( current, axis, /* allow_direct= */ transform_type != FFTTransformType::C2R || is_donatable(in, out), encoder, s); Shape step_shape = prepared.arr.shape(); if (step_real) { step_shape.back() = out.shape(axis); } Dtype step_dtype = transform_type == FFTTransformType::C2R ? 
float32 : complex64; array step_out(std::move(step_shape), step_dtype, nullptr, {}); execute_fft(prepared.arr, step_out, transform_type, inverse_, encoder); encoder.add_temporary(step_out); current = {std::move(step_out), std::move(prepared.order)}; } if (inverse_) { apply_inverse_scale(current.arr, axes_, out, encoder); } restore_output_layout(current, out); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/gemms/cublas_gemm.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/cublas_utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/dtype_utils.h" #include "mlx/utils.h" #include namespace mlx::core { namespace { cublasComputeType_t dtype_to_compute_type(Dtype dtype) { switch (dtype) { case float16: return CUBLAS_COMPUTE_32F; case bfloat16: return CUBLAS_COMPUTE_32F; case float32: return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; case float64: return CUBLAS_COMPUTE_64F; case complex64: return mlx::core::env::enable_tf32() ? 
CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; default: throw std::runtime_error( fmt::format( "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype))); } } } // namespace CublasGemm::CublasGemm( cu::Device& device, Dtype dtype, bool a_transposed, uint64_t a_rows, uint64_t a_cols, int64_t lda, bool b_transposed, uint64_t b_rows, uint64_t b_cols, int64_t ldb, int32_t batch_count, int64_t a_batch_stride, int64_t b_batch_stride) { scale_type_ = cublas_utils::dtype_to_cublas_type(dtype, "CublasGemm"); if (dtype == bfloat16 || dtype == float16) { scale_type_ = CUDA_R_32F; } cudaDataType_t cublas_dtype = cublas_utils::dtype_to_cublas_type(dtype, "CublasGemm"); init_base( device, scale_type_, dtype_to_compute_type(dtype), cublas_dtype, cublas_dtype, a_transposed, a_rows, a_cols, lda, b_transposed, b_rows, b_cols, ldb, batch_count, a_batch_stride, b_batch_stride); // alpha and beta are both host pointers cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( matmul_desc_, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); } CublasGemm::CublasGemm( cu::Device& device, Dtype dtype, bool a_transposed, uint64_t a_rows, uint64_t a_cols, int64_t lda, bool b_transposed, uint64_t b_rows, uint64_t b_cols, int64_t ldb, int64_t ldc, int32_t batch_count, int64_t a_batch_stride, int64_t b_batch_stride, int64_t c_batch_stride) : CublasGemm( device, dtype, a_transposed, a_rows, a_cols, lda, b_transposed, b_rows, b_cols, ldb, batch_count, a_batch_stride, b_batch_stride) { auto type = cublas_utils::dtype_to_cublas_type(dtype, "CublasGemm"); c_desc_ = cublas_utils::create_matrix_layout( type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride); } void CublasGemm::set_out( Dtype dtype, bool transposed, uint64_t rows, uint64_t cols, int64_t ld, int32_t batch_count, int64_t batch_stride) { CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); out_desc_ = 
cublas_utils::create_matrix_layout( cublas_utils::dtype_to_cublas_type(dtype, "CublasGemm"), cols, rows, transposed, ld, batch_count, batch_stride); } void CublasGemm::run( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const Shape& batch_shape, const Strides& a_batch_strides, const Strides& b_batch_strides, float alpha) { int batch_count = out.size() / (M_ * N_); if (batch_count / batch_shape.back() > 1) { run_batched( encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha); return; } encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); execute( encoder, gpu_ptr(out), gpu_ptr(a), gpu_ptr(b), nullptr, alpha); } void CublasGemm::run( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const array& c, const Shape& batch_shape, const Strides& a_batch_strides, const Strides& b_batch_strides, const Strides& c_batch_strides, float alpha, float beta) { int batch_count = out.size() / (M_ * N_); if (batch_count / batch_shape.back() > 1) { run_batched( encoder, out, a, b, c, batch_shape, a_batch_strides, b_batch_strides, c_batch_strides, alpha, beta); return; } encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(c); encoder.set_output_array(out); execute( encoder, gpu_ptr(out), gpu_ptr(a), gpu_ptr(b), gpu_ptr(c), alpha, beta); } void CublasGemm::execute( cu::CommandEncoder& encoder, void* out, const void* a, const void* b, const void* c, const float alpha /* = 1 */, const float beta /* = 0 */) { const void* alpha_ptr = α const void* beta_ptr = β complex64_t alpha_c, beta_c; if (scale_type_ == CUDA_C_32F) { alpha_c = complex64_t{alpha, 0.0f}; beta_c = complex64_t{beta, 0.0f}; alpha_ptr = &alpha_c; beta_ptr = &beta_c; } execute_matmul(encoder, out, a, b, c, alpha_ptr, beta_ptr); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/gemms/cublas_gemm.h ================================================ // 
Copyright © 2025 Apple Inc. #pragma once #include "mlx/array.h" #include "mlx/backend/cuda/cublas_utils.h" #include "mlx/backend/cuda/device.h" #include namespace mlx::core { class CublasGemm : public CublasMatmulBase { public: CublasGemm( cu::Device& device, Dtype dtype, bool a_transposed, uint64_t a_rows, uint64_t a_cols, int64_t lda, bool b_transposed, uint64_t b_rows, uint64_t b_cols, int64_t ldb, int32_t batch_count, int64_t a_batch_stride, int64_t b_batch_stride); CublasGemm( cu::Device& device, Dtype dtype, bool a_transposed, uint64_t a_rows, uint64_t a_cols, int64_t lda, bool b_transposed, uint64_t b_rows, uint64_t b_cols, int64_t ldb, int64_t ldc, int32_t batch_count, int64_t a_batch_stride, int64_t b_batch_stride, int64_t c_batch_stride); // The output's descriptor is inferred from inputs by default, use this method // for unusual output. void set_out( Dtype dtype, bool transposed, uint64_t rows, uint64_t cols, int64_t ld, int32_t batch_count, int64_t batch_stride); void run( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const Shape& batch_shape, const Strides& a_batch_strides, const Strides& b_batch_strides, float alpha = 1.0f); void run( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const array& c, const Shape& batch_shape, const Strides& a_batch_strides, const Strides& b_batch_strides, const Strides& c_batch_strides, float alpha, float beta); private: void run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const Shape& batch_shape, const Strides& a_batch_strides, const Strides& b_batch_strides, float alpha); void run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const array& c, const Shape& batch_shape, const Strides& a_batch_strides, const Strides& b_batch_strides, const Strides& c_batch_strides, float alpha, float beta); void execute( cu::CommandEncoder& encoder, void* out, const void* a, const void* b, const void* c, float 
alpha = 1, float beta = 0); }; } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/common/utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h" namespace mlx::core { void CublasGemm::run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const Shape& batch_shape, const Strides& a_batch_strides, const Strides& b_batch_strides, float alpha) { encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); auto nbatch = out.size() / (M_ * N_ * batch_shape.back()); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); auto concurrent = encoder.concurrent_context(); for (size_t i = 0; i < nbatch; ++i) { execute( encoder, gpu_ptr(out) + out.itemsize() * i * batch_shape.back() * M_ * N_, gpu_ptr(a) + a.itemsize() * a_it.loc, gpu_ptr(b) + b.itemsize() * b_it.loc, nullptr, alpha); a_it.step(); b_it.step(); } } void CublasGemm::run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const array& c, const Shape& batch_shape, const Strides& a_batch_strides, const Strides& b_batch_strides, const Strides& c_batch_strides, float alpha, float beta) { encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(c); encoder.set_output_array(out); auto nbatch = out.size() / (M_ * N_ * batch_shape.back()); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); auto concurrent = encoder.concurrent_context(); for (size_t i = 0; i < nbatch; ++i) { execute( encoder, gpu_ptr(out) + 
out.itemsize() * i * batch_shape.back() * M_ * N_, gpu_ptr(a) + a.itemsize() * a_it.loc, gpu_ptr(b) + b.itemsize() * b_it.loc, gpu_ptr(c) + c.itemsize() * c_it.loc, alpha, beta); a_it.step(); b_it.step(); c_it.step(); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include namespace mlx::core { namespace cu { namespace cg = cooperative_groups; template __global__ void set_mm_device_pointers_nd( int8_t** pointers, int8_t* a_start, int8_t* b_start, int8_t* out_start, int item_size, const __grid_constant__ cuda::std::array batch_shape, const __grid_constant__ cuda::std::array a_batch_strides, const __grid_constant__ cuda::std::array b_batch_strides, int64_t batch_stride, int batch_count) { auto index = cg::this_grid().thread_rank(); if (index >= batch_count) { return; } auto [a_offset, b_offset] = elem_to_loc_nd( index, batch_shape.data(), a_batch_strides.data(), b_batch_strides.data()); pointers[index] = a_start + item_size * a_offset; pointers[index + batch_count] = b_start + item_size * b_offset; pointers[index + 2 * batch_count] = out_start + item_size * index * batch_stride; } __global__ void set_mm_device_pointers_g( int8_t** pointers, int8_t* a_start, int8_t* b_start, int8_t* out_start, int item_size, const __grid_constant__ Shape batch_shape, const __grid_constant__ Strides a_batch_strides, const __grid_constant__ Strides b_batch_strides, int64_t batch_stride, int batch_ndim, int batch_count) { auto index = cg::this_grid().thread_rank(); if (index >= batch_count) { return; } auto [a_offset, b_offset] = elem_to_loc( index, batch_shape.data(), a_batch_strides.data(), b_batch_strides.data(), batch_ndim); pointers[index] = a_start + item_size * a_offset; pointers[index + 
batch_count] = b_start + item_size * b_offset; pointers[index + 2 * batch_count] = out_start + item_size * index * batch_stride; } template __global__ void set_addmm_device_pointers_nd( int8_t** pointers, int8_t* a_start, int8_t* b_start, int8_t* c_start, int8_t* out_start, int item_size, const __grid_constant__ cuda::std::array batch_shape, const __grid_constant__ cuda::std::array a_batch_strides, const __grid_constant__ cuda::std::array b_batch_strides, const __grid_constant__ cuda::std::array c_batch_strides, int64_t batch_stride, int batch_count) { auto index = cg::this_grid().thread_rank(); if (index >= batch_count) { return; } auto [a_offset, b_offset, c_offset] = elem_to_loc_nd( index, batch_shape.data(), a_batch_strides.data(), b_batch_strides.data(), c_batch_strides.data()); pointers[index] = a_start + item_size * a_offset; pointers[index + batch_count] = b_start + item_size * b_offset; pointers[index + 2 * batch_count] = c_start + item_size * c_offset; pointers[index + 3 * batch_count] = out_start + item_size * index * batch_stride; } __global__ void set_addmm_device_pointers_g( int8_t** pointers, int8_t* a_start, int8_t* b_start, int8_t* c_start, int8_t* out_start, int item_size, const __grid_constant__ Shape batch_shape, const __grid_constant__ Strides a_batch_strides, const __grid_constant__ Strides b_batch_strides, const __grid_constant__ Strides c_batch_strides, int64_t batch_stride, int batch_ndim, int batch_count) { auto index = cg::this_grid().thread_rank(); if (index >= batch_count) { return; } auto [a_offset, b_offset, c_offset] = elem_to_loc( index, batch_shape.data(), a_batch_strides.data(), b_batch_strides.data(), c_batch_strides.data(), batch_ndim); pointers[index] = a_start + item_size * a_offset; pointers[index + batch_count] = b_start + item_size * b_offset; pointers[index + 2 * batch_count] = c_start + item_size * c_offset; pointers[index + 3 * batch_count] = out_start + item_size * index * batch_stride; } } // namespace cu namespace { 
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) { auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY; CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( desc, CUBLASLT_MATRIX_LAYOUT_BATCH_MODE, &batch_mode, sizeof(batch_mode))); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t))); } } // namespace void CublasGemm::run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const Shape& batch_shape, const Strides& a_batch_strides, const Strides& b_batch_strides, float alpha) { int batch_count = out.size() / (M_ * N_); set_pointer_mode(a_desc_, batch_count); set_pointer_mode(b_desc_, batch_count); set_pointer_mode(out_desc_, batch_count); // Launch kernel to set device offsets auto pointers = array( cu::malloc_async(batch_count * sizeof(void*) * 3, encoder), {batch_count * 3}, uint64); encoder.add_temporary(pointers); encoder.set_output_array(pointers); int block_dims = std::min(batch_count, 256); int num_blocks = cuda::ceil_div(batch_count, block_dims); int64_t batch_stride = M_ * N_; int item_size = out.itemsize(); int ndim = batch_shape.size(); if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto ndim_constant) { encoder.add_kernel_node( cu::set_mm_device_pointers_nd, num_blocks, block_dims, gpu_ptr(pointers), gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), const_param(b_batch_strides), batch_stride, batch_count); }); } else { encoder.add_kernel_node( cu::set_mm_device_pointers_g, num_blocks, block_dims, gpu_ptr(pointers), gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), const_param(b_batch_strides), batch_stride, ndim, batch_count); } // Run matmul encoder.set_input_array(pointers); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); auto a_pointers = gpu_ptr(pointers); auto b_pointers = a_pointers + 
batch_count; auto out_pointers = b_pointers + batch_count; execute( encoder, reinterpret_cast(out_pointers), reinterpret_cast(a_pointers), reinterpret_cast(b_pointers), nullptr, alpha); } void CublasGemm::run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const array& c, const Shape& batch_shape, const Strides& a_batch_strides, const Strides& b_batch_strides, const Strides& c_batch_strides, float alpha, float beta) { int batch_count = out.size() / (M_ * N_); set_pointer_mode(a_desc_, batch_count); set_pointer_mode(b_desc_, batch_count); set_pointer_mode(c_desc_, batch_count); set_pointer_mode(out_desc_, batch_count); // Launch kernel to set device offsets auto pointers = array( cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder), {batch_count * 4}, uint64); encoder.add_temporary(pointers); encoder.set_output_array(pointers); int block_dims = std::min(batch_count, 256); int num_blocks = cuda::ceil_div(batch_count, block_dims); int64_t batch_stride = M_ * N_; int item_size = out.itemsize(); int ndim = batch_shape.size(); if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto ndim_constant) { encoder.add_kernel_node( cu::set_addmm_device_pointers_nd, num_blocks, block_dims, gpu_ptr(pointers), gpu_ptr(a), gpu_ptr(b), gpu_ptr(c), gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), const_param(b_batch_strides), const_param(c_batch_strides), batch_stride, batch_count); }); } else { encoder.add_kernel_node( cu::set_addmm_device_pointers_g, num_blocks, block_dims, gpu_ptr(pointers), gpu_ptr(a), gpu_ptr(b), gpu_ptr(c), gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), const_param(b_batch_strides), const_param(c_batch_strides), batch_stride, ndim, batch_count); } // Run matmul encoder.set_input_array(pointers); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(c); encoder.set_output_array(out); auto a_pointers = gpu_ptr(pointers); auto b_pointers = 
a_pointers + batch_count; auto c_pointers = b_pointers + batch_count; auto out_pointers = c_pointers + batch_count; execute( encoder, reinterpret_cast(out_pointers), reinterpret_cast(a_pointers), reinterpret_cast(b_pointers), reinterpret_cast(c_pointers), alpha, beta); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/gemms/gemv.cu ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/gemms/gemv.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/dtype_utils.h" #include #include namespace mlx::core::cu { namespace cg = cooperative_groups; static constexpr int rows_per_block = 8; // Accumulator type selection per input element type T. template struct GemvAccType { using type = T; }; template <> struct GemvAccType<__half> { using type = float; }; template <> struct GemvAccType<__nv_bfloat16> { using type = float; }; template <> struct GemvAccType { using type = float; }; template <> struct GemvAccType { using type = double; }; template <> struct GemvAccType { using type = cu::complex64_t; }; template __device__ void gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) { auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); auto g_idx = block.group_index(); auto t_idx = block.thread_index(); int row = g_idx.x * rows_per_block + t_idx.y; if (row < rows) { using Acc = typename GemvAccType::type; Acc sum = Acc(0); for (int col = n_per_thread * warp.thread_rank(); col < cols; col += (WARP_SIZE * n_per_thread)) { auto local_mat = unsafe_load_vector(mat + row * cols + col, 0); auto local_vec = unsafe_load_vector(vec + col, 0); #pragma unroll for (int j = 0; j < n_per_thread; ++j) { sum += static_cast(local_mat[j]) * static_cast(local_vec[j]); } } sum = cg::reduce(warp, sum, cg::plus{}); if (warp.thread_rank() == 0) { out[row] = static_cast(sum); } } } template __global__ void gemv_single(const T* mat, const T* vec, T* out, int 
rows, int cols) { gemv_impl(mat, vec, out, rows, cols); } template __global__ void gemv_batched( const T* mat, const T* vec, T* out, int rows, int cols, const __grid_constant__ Shape batch_shape, const __grid_constant__ Strides mat_batch_strides, const __grid_constant__ Strides vec_batch_strides, int batch_ndim) { auto block = cg::this_thread_block(); auto batch_idx = block.group_index().y; auto [vec_offset, mat_offset] = elem_to_loc( batch_idx, batch_shape.data(), vec_batch_strides.data(), mat_batch_strides.data(), batch_ndim); gemv_impl( mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols); } template __global__ void gemv_gather( const T* mat, const T* vec, T* out, uint32_t* mat_indices, uint32_t* vec_indices, int rows, int cols, const __grid_constant__ Shape mat_batch_shape, const __grid_constant__ Strides mat_batch_strides, int mat_batch_ndim, const __grid_constant__ Shape vec_batch_shape, const __grid_constant__ Strides vec_batch_strides, int vec_batch_ndim, const __grid_constant__ Shape index_shape, const __grid_constant__ Strides mat_index_strides, const __grid_constant__ Strides vec_index_strides, int index_batch_ndim) { auto block = cg::this_thread_block(); auto indices_idx = block.group_index().y; uint32_t index_mat, index_vec; if (index_batch_ndim > 1) { auto [mat_idx_offset, vec_idx_offset] = elem_to_loc( indices_idx, index_shape.data(), mat_index_strides.data(), vec_index_strides.data(), index_batch_ndim); index_mat = mat_indices[mat_idx_offset]; index_vec = vec_indices[vec_idx_offset]; } else { index_mat = mat_indices[indices_idx * mat_index_strides[0]]; index_vec = vec_indices[indices_idx * vec_index_strides[0]]; } int64_t mat_offset; if (mat_batch_ndim > 1) { mat_offset = elem_to_loc( index_mat, mat_batch_shape.data(), mat_batch_strides.data(), mat_batch_ndim); } else { mat_offset = index_mat * mat_batch_strides[0]; } int64_t vec_offset; if (vec_batch_ndim > 1) { vec_offset = elem_to_loc( index_vec, vec_batch_shape.data(), 
vec_batch_strides.data(), vec_batch_ndim); } else { vec_offset = index_vec * vec_batch_strides[0]; } gemv_impl( mat + mat_offset, vec + vec_offset, out + indices_idx * rows, rows, cols); } bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) { return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed)); } template void dispatch_n_per_thread(int n_per_thread, F&& f) { switch (n_per_thread) { case 1: f(std::integral_constant{}); break; case 2: f(std::integral_constant{}); break; case 4: f(std::integral_constant{}); break; } } void gemv( const array& a, const array& b, array& out, int M, int N, int K, uint32_t batch_count, const mlx::core::Shape& batch_shape, const mlx::core::Strides& a_batch_strides, const mlx::core::Strides& b_batch_strides, CommandEncoder& encoder) { encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) { using DataType = cuda_type_t; dim3 block_dims{WARP_SIZE, rows_per_block}; const DataType* mat; const DataType* vec; int rows; int cols = K; auto mat_strides = const_param(a_batch_strides); auto vec_strides = const_param(b_batch_strides); if (M == 1) { mat = gpu_ptr(b); vec = gpu_ptr(a); rows = N; std::swap(mat_strides, vec_strides); } else { mat = gpu_ptr(a); vec = gpu_ptr(b); rows = M; } uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; int n_per_t; if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) { n_per_t = 4; } else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) { n_per_t = 2; } else { n_per_t = 1; } dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) { if (batch_count == 1) { auto kernel = gemv_single; encoder.add_kernel_node( kernel, num_blocks_x, block_dims, mat, vec, gpu_ptr(out), rows, cols); } else { auto kernel = gemv_batched; encoder.add_kernel_node( kernel, dim3{num_blocks_x, batch_count}, block_dims, mat, vec, gpu_ptr(out), rows, cols, 
const_param(batch_shape), mat_strides, vec_strides, batch_shape.size()); } }); }); } void gather_mv( const array& mat_, const array& vec_, const array& mat_indices, const array& vec_indices, array& out, int N, int K, CommandEncoder& encoder) { encoder.set_input_array(mat_); encoder.set_input_array(vec_); encoder.set_input_array(mat_indices); encoder.set_input_array(vec_indices); encoder.set_output_array(out); dispatch_inexact_types(out.dtype(), "gather_mv", [&](auto type_tag) { using DataType = cuda_type_t; dim3 block_dims{WARP_SIZE, rows_per_block}; int rows = N; int cols = K; uint32_t batch_size = static_cast(out.size() / N); const DataType* mat = gpu_ptr(mat_); const DataType* vec = gpu_ptr(vec_); uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; int n_per_t; if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) { n_per_t = 4; } else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) { n_per_t = 2; } else { n_per_t = 1; } dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) { auto kernel = gemv_gather; encoder.add_kernel_node( kernel, dim3{num_blocks_x, batch_size}, block_dims, mat, vec, gpu_ptr(out), gpu_ptr(mat_indices), gpu_ptr(vec_indices), rows, cols, const_param(mat_.shape()), const_param(mat_.strides()), mat_.ndim() - 2, const_param(vec_.shape()), const_param(vec_.strides()), vec_.ndim() - 2, const_param(mat_indices.shape()), const_param(mat_indices.strides()), const_param(vec_indices.strides()), mat_indices.ndim()); }); }); } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/gemms/gemv.h ================================================ // Copyright © 2025 Apple Inc. 
#pragma once #include "mlx/backend/cuda/device.h" namespace mlx::core::cu { bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed); void gemv( const array& a, const array& b, array& out, int M, int N, int K, uint32_t batch_count, const mlx::core::Shape& batch_shape, const mlx::core::Strides& a_batch_strides, const mlx::core::Strides& b_batch_strides, CommandEncoder& encoder); void gather_mv( const array& mat, const array& vec, const array& mat_indices, const array& vec_indices, array& out, int N, int K, CommandEncoder& encoder); } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/gemms/grouped_gemm.h ================================================ // Copyright © 2025 Apple Inc. #pragma once namespace mlx::core { namespace cu { class CommandEncoder; } class array; void cutlass_grouped_gemm_unaligned( bool a_transposed, int lda, bool b_transposed, int ldb, int group_count, const array& a, const array& b, const array& indices, array& out, cu::CommandEncoder& encoder); void cutlass_segmented_mm( bool a_transposed, int lda, bool b_transposed, int ldb, int num_segments, int M, int N, const array& a, const array& b, const array& segments, array& out, cu::CommandEncoder& encoder); } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu ================================================ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/cublas_utils.h"
#include "mlx/backend/cuda/cutlass_utils.cuh"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/grouped_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"

#include
#include
#include
#include
#include

namespace mlx::core {

using ProblemSize = cutlass::gemm::GemmCoord;

namespace cu {

namespace cg = cooperative_groups;

// Builds the per-group arguments (problem sizes, leading dimensions, operand
// pointers) for a grouped GEMM. Row ranges come from the sorted group
// `indices`. This must run as a single block with at least group_count
// threads, and with group_count * sizeof(uint32_t) bytes of dynamic shared
// memory for `cum_histo`.
template
__global__ void prepare_grouped_mm_data(
    const uint32_t* indices,
    size_t size,
    int group_count,
    int K,
    int N,
    int lda,
    int ldb,
    int item_size,
    int8_t* a_start,
    int8_t* b_start,
    int8_t* out_start,
    int a_batch_stride,
    int b_batch_stride,
    int out_batch_stride,
    ProblemSize* problem_sizes,
    int64_t* a_lds,
    int64_t* b_lds,
    int64_t* out_lds,
    void** a_ptrs,
    void** b_ptrs,
    void** out_ptrs) {
  auto block = cg::this_thread_block();

  // cumsum(histogram(indices)) - offset for each group.
  extern __shared__ uint32_t cum_histo[];

  int group = block.thread_rank();
  if (group < group_count) {
    cum_histo[group] = 0;
  }

  block.sync();

  // Since |indices| is sorted, the position where element changes would be its
  // cumulative histogram.
  size_t elems_per_block = block.num_threads() * N_READS;
  for (int r = 0; r < cuda::ceil_div(size, elems_per_block); ++r) {
    // TODO: Use vectorized read.
    for (int i = 0; i < N_READS; ++i) {
      size_t pos = r * elems_per_block + group * N_READS + i;
      if (pos >= size) {
        break;
      }
      auto elem = indices[pos];
      auto next = pos < size - 1 ? indices[pos + 1] : group_count;
      while (elem < next) {
        cum_histo[elem] = pos + 1;
        elem++;
      }
    }
  }

  block.sync();

  if (group < group_count) {
    // Fill shapes.
    int delta = group == 0 ? cum_histo[0] : cum_histo[group] - cum_histo[group - 1];
    problem_sizes[group] = {delta, N, K};
    a_lds[group] = lda;
    b_lds[group] = ldb;
    out_lds[group] = N;
    // Fill pointers. The byte offsets are computed in 64 bits. Without the
    // widening, `offset` (uint32_t) makes the a/out products unsigned 32-bit
    // arithmetic, which wraps past 4 GiB. The b product is signed int
    // arithmetic, which overflows past 2 GiB.
    auto offset = group == 0 ? 0 : cum_histo[group - 1];
    a_ptrs[group] = a_start + int64_t(offset) * item_size * a_batch_stride;
    b_ptrs[group] = b_start + int64_t(group) * item_size * b_batch_stride;
    out_ptrs[group] =
        out_start + int64_t(offset) * item_size * out_batch_stride;
  }
}

// One thread per segment. It fills the M x N x K_i problem for segment idx,
// where K_i is the length of [segments[2*idx], segments[2*idx+1]) (0 if
// empty). It also offsets the A/B pointers to the segment start along K and
// points the output at slice idx.
__global__ void prepare_segmented_mm_data(
    const uint32_t* segments,
    int num_segments,
    int M,
    int N,
    int lda,
    int ldb,
    int item_size,
    bool a_transposed,
    bool b_transposed,
    int8_t* a_start,
    int8_t* b_start,
    int8_t* out_start,
    ProblemSize* problem_sizes,
    int64_t* a_lds,
    int64_t* b_lds,
    int64_t* out_lds,
    void** a_ptrs,
    void** b_ptrs,
    void** out_ptrs) {
  int idx = cg::this_grid().thread_rank();
  if (idx >= num_segments)
    return;

  int64_t start = segments[2 * idx];
  int64_t end = segments[2 * idx + 1];
  int K_i = (end > start) ? static_cast(end - start) : 0;

  problem_sizes[idx] = {M, N, K_i};
  a_lds[idx] = lda;
  b_lds[idx] = ldb;
  out_lds[idx] = N;

  // Offset into K dimension depends on layout:
  // A [M,K]: row-major offset = start, col-major offset = start * lda
  // B [K,N]: row-major offset = start * ldb, col-major offset = start
  int64_t a_offset = a_transposed ? start * lda : start;
  int64_t b_offset = b_transposed ? start : start * ldb;
  a_ptrs[idx] = a_start + a_offset * item_size;
  b_ptrs[idx] = b_start + b_offset * item_size;
  out_ptrs[idx] = out_start + static_cast(idx) * M * N * item_size;
}

} // namespace cu

namespace {

// Shared GEMM configuration for every type and arch.
template
struct CommonGemmConfiguration {
  using Element = T;
  using Arch = ArchTag;
  using Accumulator = std::conditional_t<(sizeof(T) < 4), float, T>;
  using EpilogueOutputOp = cutlass::epilogue::thread::
      LinearCombination;
};

// Slow GEMM configuration as fallback.
// Primary template: a SIMT (non-tensor-core) configuration that is used
// whenever no specialization below matches.
template <
    typename T,
    typename Arch,
    int kAlignmentC = 1,
    bool kEnableTF32 = false,
    typename Enable = void>
struct GemmConfiguration : public CommonGemmConfiguration {
  using OpClass = cutlass::arch::OpClassSimt;
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>;
  using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  static const int kAlignmentAB = 1;
  static const int kStages = 2;
};

// Specialized GEMM configuration for sm80 and later.
// NOTE(review): as written, this partial specialization only matches when the
// kEnableTF32 argument is `true`. get_grouped_mm_funcion appears to derive
// that argument from env::enable_tf32() for every dtype. Confirm that 16-bit
// types are meant to fall back to the SIMT configuration when TF32 is
// disabled.
template
struct GemmConfiguration<
    T,
    Arch,
    kAlignmentC,
    true,
    std::enable_if_t= 80 && sizeof(T) <= 4>>
    : public CommonGemmConfiguration {
  using OpClass = cutlass::arch::OpClassTensorOp;
  using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  // K extent of the MMA instruction is 32 bytes' worth of elements.
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32 / sizeof(T)>;
  static const int kAlignmentAB = 1;
  static const int kStages = 2;
};

// Specialized GEMM configuration for tf32 on sm80.
template
struct GemmConfiguration : public CommonGemmConfiguration {
  using OpClass = cutlass::arch::OpClassTensorOp;
  using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
  static const int kAlignmentAB = 1;
  static const int kStages = 3; // use SM80_CP_ASYNC
};

// Get direct access to kernel.
// encode() records the grouped kernel into the command encoder instead of
// launching it directly. The grid has params_.threadblock_count blocks of
// GemmKernel::kThreadCount threads, with sizeof(SharedStorage) bytes of
// dynamic shared memory.
template
class GemmGroupedEncoder
    : public cutlass::gemm::device::GemmGrouped {
 public:
  void encode(cu::CommandEncoder& encoder) {
    encoder.add_kernel_node_ex(
        cutlass::Kernel,
        {static_cast(this->params_.threadblock_count), 1, 1},
        {GemmKernel::kThreadCount, 1, 1},
        {},
        sizeof(typename GemmKernel::SharedStorage),
        this->params_);
  }
};

// Invoke the grouped GEMM of CUTLASS 2.x API, which supports small alignments.
template void grouped_gemm_v2( bool a_transposed, bool b_transposed, int group_count, ProblemSize* problem_sizes, int64_t* a_lds, int64_t* b_lds, int64_t* out_lds, void* a_ptrs, void* b_ptrs, void* out_ptrs, cu::CommandEncoder& encoder) { dispatch_bool(a_transposed, [&](auto a_transposed_tag) { dispatch_bool(b_transposed, [&](auto b_transposed_tag) { using LayoutA = std::conditional_t< a_transposed_tag.value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; using LayoutB = std::conditional_t< b_transposed_tag.value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< typename GemmConfiguration::Element, LayoutA, cutlass::ComplexTransform::kNone, GemmConfiguration::kAlignmentAB, typename GemmConfiguration::Element, LayoutB, cutlass::ComplexTransform::kNone, GemmConfiguration::kAlignmentAB, typename GemmConfiguration::Element, cutlass::layout::RowMajor, typename GemmConfiguration::Accumulator, typename GemmConfiguration::OpClass, typename GemmConfiguration::Arch, typename GemmConfiguration::ThreadblockShape, typename GemmConfiguration::WarpShape, typename GemmConfiguration::InstructionShape, typename GemmConfiguration::EpilogueOutputOp, cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, GemmConfiguration::kStages>::GemmKernel; using GemmGrouped = GemmGroupedEncoder; static int threadblock_count = GemmGrouped::sufficient(); typename GemmGrouped::Arguments args( problem_sizes, group_count, threadblock_count, {/* alpha */ 1, /* beta */ 0}, reinterpret_cast(a_ptrs), reinterpret_cast(b_ptrs), reinterpret_cast(out_ptrs), reinterpret_cast(out_ptrs), a_lds, b_lds, out_lds, out_lds); GemmGrouped gemm; CHECK_CUTLASS_ERROR(gemm.initialize( args, allocate_workspace(encoder, gemm.get_workspace_size(args)), encoder.stream())); gemm.encode(encoder); }); }); } template void dispatch_cutlass_arch(cu::Device& device, F&& f) { if (device.compute_capability_major() < 8) { 
f(type_identity{}); } else if (device.compute_capability_major() == 8) { f(type_identity{}); } else { f(type_identity{}); } } auto* get_grouped_mm_funcion(Dtype dtype, int N, cu::Device& device) { auto* fun = grouped_gemm_v2>; dispatch_float_types(dtype, "grouped_gemm_v2", [&](auto type_tag) { using DataType = cutlass_type_t; dispatch_cutlass_arch(device, [&](auto arch_tag) { using Arch = MLX_GET_TYPE(arch_tag); dispatch_bool(N % 8 == 0, [&](auto is_out_aligned) { constexpr int kAlignmentC = is_out_aligned ? 8 : 1; dispatch_bool(env::enable_tf32(), [&](auto kEnableTF32) { fun = grouped_gemm_v2< GemmConfiguration>; }); }); }); }); return fun; } } // namespace void cutlass_grouped_gemm_unaligned( bool a_transposed, int lda, bool b_transposed, int ldb, int group_count, const array& a, const array& b, const array& indices, array& out, cu::CommandEncoder& encoder) { int K = a.shape(-1); int N = b.shape(-1); // Prepare device pointers for matmul. int problem_sizes_nbytes = group_count * cuda::ceil_div(sizeof(ProblemSize), 8) * 8; int nbytes = problem_sizes_nbytes + group_count * (3 * sizeof(void*) + 3 * sizeof(int64_t)); nbytes = cuda::ceil_div(nbytes, 256) * 256; array gemm_args(cu::malloc_async(nbytes, encoder), {nbytes}, int8); encoder.add_temporary(gemm_args); ProblemSize* problem_sizes = gpu_ptr(gemm_args); int64_t* a_lds = gpu_ptr(gemm_args) + problem_sizes_nbytes / 8; int64_t* b_lds = a_lds + group_count; int64_t* out_lds = b_lds + group_count; void** a_ptrs = reinterpret_cast(out_lds + group_count); void** b_ptrs = a_ptrs + group_count; void** out_ptrs = b_ptrs + group_count; // Fill the pointers by computing offsets from indices. constexpr int N_READS = 4; int n_threads = cuda::ceil_div(indices.size(), N_READS); n_threads = group_count < n_threads ? 
n_threads : group_count; dim3 block_dims(std::min(n_threads, 1024)); dim3 num_blocks(1); encoder.set_input_array(indices); encoder.set_output_array(gemm_args); encoder.add_kernel_node_ex( cu::prepare_grouped_mm_data, num_blocks, block_dims, {}, group_count * sizeof(uint32_t), // sizeof(cum_histo) gpu_ptr(indices), indices.size(), group_count, K, N, lda, ldb, out.itemsize(), gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), a.shape(-2) * a.shape(-1), // a_batch_stride b.shape(-2) * b.shape(-1), // b_batch_stride out.shape(-2) * out.shape(-1), // out_batch_stride problem_sizes, a_lds, b_lds, out_lds, a_ptrs, b_ptrs, out_ptrs); // Invoke grouped GEMM. encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(gemm_args); encoder.set_output_array(out); auto* fun = get_grouped_mm_funcion(a.dtype(), N, encoder.device()); fun(a_transposed, b_transposed, group_count, problem_sizes, a_lds, b_lds, out_lds, a_ptrs, b_ptrs, out_ptrs, encoder); } void cutlass_segmented_mm( bool a_transposed, int lda, bool b_transposed, int ldb, int num_segments, int M, int N, const array& a, const array& b, const array& segments, array& out, cu::CommandEncoder& encoder) { // Allocate grouped GEMM args on device. int problem_sizes_nbytes = num_segments * cuda::ceil_div(sizeof(ProblemSize), 8) * 8; int nbytes = problem_sizes_nbytes + num_segments * (3 * sizeof(void*) + 3 * sizeof(int64_t)); nbytes = cuda::ceil_div(nbytes, 256) * 256; array gemm_args(cu::malloc_async(nbytes, encoder), {nbytes}, int8); encoder.add_temporary(gemm_args); ProblemSize* problem_sizes = gpu_ptr(gemm_args); int64_t* a_lds = gpu_ptr(gemm_args) + problem_sizes_nbytes / 8; int64_t* b_lds = a_lds + num_segments; int64_t* out_lds = b_lds + num_segments; void** a_ptrs = reinterpret_cast(out_lds + num_segments); void** b_ptrs = a_ptrs + num_segments; void** out_ptrs = b_ptrs + num_segments; // Build problem descriptions from segments on the GPU. 
int block_size = std::min(num_segments, 256); int num_blocks = cuda::ceil_div(num_segments, block_size); encoder.set_input_array(segments); encoder.set_output_array(gemm_args); encoder.add_kernel_node_ex( cu::prepare_segmented_mm_data, dim3(num_blocks), dim3(block_size), {}, 0, gpu_ptr(segments), num_segments, M, N, static_cast(lda), static_cast(ldb), static_cast(out.itemsize()), a_transposed, b_transposed, gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), problem_sizes, a_lds, b_lds, out_lds, a_ptrs, b_ptrs, out_ptrs); // Dispatch grouped GEMM. encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(gemm_args); encoder.set_output_array(out); auto* fun = get_grouped_mm_funcion(a.dtype(), N, encoder.device()); fun(a_transposed, b_transposed, num_segments, problem_sizes, a_lds, b_lds, out_lds, a_ptrs, b_ptrs, out_ptrs, encoder); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/hadamard.cu ================================================ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include
#include
#include
#include
#include
#include
#include

namespace mlx::core {

namespace {

// Upper bound on the block size used by the radix-m pass.
constexpr int MAX_HADAMARD_THREADS_PER_BLOCK = 256;

// Generate CUDA source defining `hadamard_radix_m(float* x)`, which replaces
// the m values in x by their product with the m x m matrix taken from
// hadamard_matrices(). For m == 1 the generated function body is empty.
// Each matrix row is emitted as `tmp[r] = <c0> x[0] <c1> x[1] ...`, so the row
// characters are presumably '+'/'-' signs (TODO confirm in hadamard.h).
// Throws std::runtime_error when no matrix exists for |m|.
std::string gen_hadamard_codelet(int m) {
  std::ostringstream source;
  source << "namespace mlx::core::cu {\n";
  source << "__device__ __forceinline__ void hadamard_radix_m(float* x) {\n";
  if (m == 1) {
    source << "}\n";
    source << "} // namespace mlx::core::cu\n";
    return source.str();
  }

  auto h_matrices = hadamard_matrices();
  auto it = h_matrices.find(m);
  if (it == h_matrices.end()) {
    throw std::runtime_error("[hadamard] Invalid radix m.");
  }
  auto& matrix = it->second;

  source << " float tmp[" << m << "];\n";
  // Start at 1: assumes the matrix text begins with a newline — TODO confirm.
  auto start = 1;
  auto end = matrix.find('\n', start);

  // One generated statement per newline-terminated row.
  int row_idx = 0;
  while (end != std::string_view::npos) {
    auto row = matrix.substr(start, end - start);
    source << " tmp[" << row_idx << "] =";
    for (int i = 0; i < row.length(); ++i) {
      source << " " << row[i] << " x[" << i << "]";
    }
    source << ";\n";
    start = end + 1;
    end = matrix.find('\n', start);
    row_idx++;
  }
  source << " #pragma unroll\n";
  source << " for (int i = 0; i < " << m << "; ++i) { x[i] = tmp[i]; }\n";
  source << "}\n";
  source << "} // namespace mlx::core::cu\n";
  return source.str();
}

// Fully-qualified name of the power-of-two `hadamard_n` kernel instantiation.
std::string hadamard_n_kernel_name(
    const Dtype& dtype,
    int n,
    int max_radix,
    int read_width,
    int stride) {
  return fmt::format(
      "mlx::core::cu::hadamard_n<{}, {}, {}, {}, {}>",
      dtype_to_cuda_type(dtype),
      n,
      max_radix,
      read_width,
      stride);
}

// Fully-qualified name of the radix-m `hadamard_m` kernel instantiation.
std::string
hadamard_m_kernel_name(const Dtype& dtype, int n, int m, int read_width) {
  return fmt::format(
      "mlx::core::cu::hadamard_m<{}, {}, {}, {}>",
      dtype_to_cuda_type(dtype),
      n,
      m,
      read_width);
}

// Apply the Hadamard transform to contiguous |x|, writing |y|, in up to three
// passes:
//   1. an n1-sized pass with stride n2 (only when n1 > 1), x -> y;
//   2. an n2-sized pass with stride 1, reading y if pass 1 ran, else x;
//   3. a radix-m pass in place on y (only when m > 1).
// |scale| is applied exactly once: in pass 2 when m == 1, otherwise in pass 3.
// Grid sizes are capped at 65535 blocks.
void hadamard_mn_contiguous(
    const array& x,
    array& y,
    int m,
    int n1,
    int n2,
    float scale,
    const Stream& s) {
  const int n = n1 * n2;
  const int read_width_n1 = (n1 == 2) ? 2 : 4;
  const int read_width_n2 = (n2 == 2) ? 2 : 4;
  const int read_width_m = (n == 2 || m == 28) ? 2 : 4;
  const int max_radix_1 = std::min(n1, 16);
  const int max_radix_2 = std::min(n2, 16);
  const float scale_n1 = 1.0f;
  const float scale_n2 = (m == 1) ? scale : 1.0f;
  const float scale_m = scale;

  const std::string n1_kernel_name =
      hadamard_n_kernel_name(x.dtype(), n1, max_radix_1, read_width_n1, n2);
  const std::string n2_kernel_name =
      hadamard_n_kernel_name(x.dtype(), n2, max_radix_2, read_width_n2, 1);
  const std::string m_kernel_name =
      hadamard_m_kernel_name(x.dtype(), n, m, read_width_m);

  // The module name encodes every parameter that changes the generated code.
  const std::string module_name = fmt::format(
      "hadamard_{}_{}_{}_{}_{}_{}_{}_{}",
      dtype_to_string(x.dtype()),
      n,
      m,
      n1,
      n2,
      read_width_n1,
      read_width_n2,
      read_width_m);

  // The lambda builds the JIT source and the list of kernels to instantiate;
  // presumably only invoked when the module is not cached yet.
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::vector kernel_names = {n2_kernel_name};
    if (n1 > 1) {
      kernel_names.push_back(n1_kernel_name);
    }
    if (m > 1) {
      kernel_names.push_back(m_kernel_name);
    }
    std::string source = R"(
        #include "mlx/backend/cuda/device/utils.cuh"
    )";
    source += gen_hadamard_codelet(m);
    source += R"(
        #include "mlx/backend/cuda/device/hadamard.cuh"
    )";
    return std::make_tuple(false, std::move(source), std::move(kernel_names));
  });

  auto& encoder = cu::get_command_encoder(s);

  // Pass 1: n1-sized transforms, x -> y.
  if (n1 > 1) {
    const int64_t num_transforms = x.size() / n1;
    const uint32_t num_blocks =
        static_cast(std::min(num_transforms, 65535));
    encoder.set_input_array(x);
    encoder.set_output_array(y);
    cu::KernelArgs args;
    args.append(x);
    args.append(y);
    args.append(scale_n1);
    args.append(num_transforms);
    auto kernel = mod.get_kernel(n1_kernel_name);
    encoder.add_kernel_node_raw(
        kernel, num_blocks, n1 / max_radix_1, {}, 0, args.args());
  }

  // Pass 2: n2-sized transforms, always run.
  {
    const auto& in = (n1 > 1) ? y : x;
    const int64_t num_transforms = x.size() / n2;
    const uint32_t num_blocks =
        static_cast(std::min(num_transforms, 65535));
    encoder.set_input_array(in);
    encoder.set_output_array(y);
    cu::KernelArgs args;
    args.append(in);
    args.append(y);
    args.append(scale_n2);
    args.append(num_transforms);
    auto kernel = mod.get_kernel(n2_kernel_name);
    encoder.add_kernel_node_raw(
        kernel, num_blocks, n2 / max_radix_2, {}, 0, args.args());
  }

  // Pass 3: radix-m transform in place on y.
  if (m > 1) {
    const int64_t num_tasks = x.size() / (m * read_width_m);
    const uint32_t block_dim = static_cast(
        std::min(num_tasks, MAX_HADAMARD_THREADS_PER_BLOCK));
    const uint32_t num_blocks = static_cast(
        std::min((num_tasks + block_dim - 1) / block_dim, 65535));
    encoder.set_input_array(y);
    encoder.set_output_array(y);
    cu::KernelArgs args;
    args.append(y);
    args.append(y);
    args.append(scale_m);
    args.append(num_tasks);
    auto kernel = mod.get_kernel(m_kernel_name);
    encoder.add_kernel_node_raw(
        kernel, num_blocks, block_dim, {}, 0, args.args());
  }
}

} // namespace

// Hadamard transform along the last axis. Only float16, bfloat16 and float32
// are accepted; other dtypes throw std::invalid_argument.
void Hadamard::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("Hadamard::eval_gpu");

  assert(inputs.size() == 1);
  auto& in = inputs[0];

  if (in.dtype() != float16 && in.dtype() != bfloat16 &&
      in.dtype() != float32) {
    throw std::invalid_argument("[hadamard] Unsupported type.");
  }

  // n = m * 2^k where m in (1, 12, 20, 28)
  auto [n, m] = decompose_hadamard(in.shape().back());
  // Above 8192, split n as n1 * n2, where n2 is the smallest power of two
  // with n2 * n2 >= n, and transform the two factors in separate passes.
  int n1 = 1;
  int n2 = n;
  if (n > 8192) {
    for (n2 = 2; n2 * n2 < n; n2 *= 2) {
    }
    n1 = n / n2;
  }

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  if (in.flags().row_contiguous) {
    // Reuse the input buffer when it can be donated.
    if (in.is_donatable()) {
      out.copy_shared_buffer(in);
    } else {
      out.set_data(cu::malloc_async(out.nbytes(), encoder));
    }
    hadamard_mn_contiguous(in, out, m, n1, n2, scale_, s);
  } else {
    // Make a contiguous copy into |out| and transform it in place.
    copy_gpu(in, out, CopyType::General, s);
    hadamard_mn_contiguous(out, out, m, n1, n2, scale_, s);
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/indexing.cpp
================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/scan.h" #include "mlx/backend/gpu/slicing.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" #include "cuda_jit_sources.h" #include #include #include #include #include #include namespace mlx::core { namespace { constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"}; constexpr const char* g_slice_ops[] = {"Maximum", "Minimum", "Add", "Multiply", ""}; void append_indices_arg( cu::KernelArgs& args, const std::vector& inputs, int nidx, int idx_ndim) { SmallVector indices(nidx); for (int i = 0; i < nidx; ++i) { indices[i] = gpu_ptr(inputs[i + 1]); } args.append(std::move(indices)); SmallVector indices_shape(nidx * idx_ndim); for (int i = 0; i < nidx; ++i) { std::copy_n( inputs[i + 1].shape().begin(), idx_ndim, indices_shape.data() + i * idx_ndim); } args.append(std::move(indices_shape)); SmallVector indices_strides(nidx * idx_ndim); for (int i = 0; i < nidx; ++i) { std::copy_n( inputs[i + 1].strides().begin(), idx_ndim, indices_strides.data() + i * idx_ndim); } args.append(std::move(indices_strides)); } } // namespace void Gather::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Gather::eval_gpu"); assert(inputs.size() > 0); const auto& src = inputs[0]; auto& s = stream(); auto& encoder = cu::get_command_encoder(s); out.set_data(cu::malloc_async(out.nbytes(), encoder)); if (out.size() == 0) { return; } int nidx = inputs.size() - 1; Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; int32_t idx_ndim = nidx > 0 ? 
inputs[1].ndim() : 0; bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) || (src.size() > INT32_MAX) || (out.size() > INT32_MAX); uint32_t slice_size = std::accumulate( slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies()); std::string module_name = fmt::format( "gather_{}_{}_{}", dtype_to_string(out.dtype()), dtype_to_string(idx_dtype), nidx); cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { std::vector kernel_names; for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { for (int large = 0; large <= 1; ++large) { kernel_names.push_back( fmt::format( "mlx::core::cu::gather<{}, {}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), dtype_to_cuda_type(idx_dtype), nidx, ndim, large ? "int64_t" : "int32_t")); } } return std::make_tuple(false, jit_source_gather, std::move(kernel_names)); }); cu::KernelArgs args; args.append(src); args.append(out); if (large) { args.append(out.size()); } else { args.append(out.size()); } args.append_ndim(src.shape()); args.append_ndim(src.strides()); args.append(src.ndim()); args.append_ndim(slice_sizes_); args.append(slice_size); args.append(axes_); append_indices_arg(args, inputs, nidx, idx_ndim); std::string kernel_name = fmt::format( "mlx::core::cu::gather<{}, {}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), dtype_to_cuda_type(idx_dtype), nidx, idx_ndim, large ? "int64_t" : "int32_t"); for (const auto& in : inputs) { encoder.set_input_array(in); } encoder.set_output_array(out); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(out, large); encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, 0, args.args()); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Gather::eval_gpu"); assert(inputs.size() > 1); auto& upd = inputs.back(); // Copy src into out. 
CopyType copy_type; if (inputs[0].data_size() == 1) { copy_type = CopyType::Scalar; } else if (inputs[0].flags().row_contiguous) { copy_type = CopyType::Vector; } else { copy_type = CopyType::General; } copy_gpu(inputs[0], out, copy_type); // Empty update. if (upd.size() == 0) { return; } int nidx = axes_.size(); Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) || (upd.size() > INT32_MAX) || (out.size() > INT32_MAX); int32_t upd_post_idx_size = std::accumulate( upd.shape().begin() + idx_ndim, upd.shape().end(), 1, std::multiplies()); const char* op = g_scatter_ops[reduce_type_]; std::string module_name = fmt::format( "scatter_{}_{}_{}_{}", dtype_to_string(out.dtype()), dtype_to_string(idx_dtype), op, nidx); auto& s = stream(); cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { std::vector kernel_names; for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { for (int large = 0; large <= 1; ++large) { kernel_names.push_back( fmt::format( "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), dtype_to_cuda_type(idx_dtype), op, nidx, ndim, large ? 
"int64_t" : "int32_t")); } } return std::make_tuple(false, jit_source_scatter, std::move(kernel_names)); }); cu::KernelArgs args; args.append(upd); args.append(out); if (large) { args.append(upd.size()); } else { args.append(upd.size()); } args.append_ndim(upd.shape()); args.append_ndim(upd.strides()); args.append(upd.ndim()); if (large) { args.append(upd_post_idx_size); } else { args.append(upd_post_idx_size); } args.append_ndim(out.shape()); args.append_ndim(out.strides()); args.append(out.ndim()); args.append(axes_); append_indices_arg(args, inputs, nidx, idx_ndim); std::string kernel_name = fmt::format( "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), dtype_to_cuda_type(idx_dtype), op, nidx, idx_ndim, large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { encoder.set_input_array(in); } encoder.set_output_array(out); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(upd, large); encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, 0, args.args()); } void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("GatherAxis::eval_gpu"); assert(inputs.size() > 1); const auto& src = inputs[0]; const auto& idx = inputs[1]; auto& s = stream(); auto& encoder = cu::get_command_encoder(s); out.set_data(cu::malloc_async(out.nbytes(), encoder)); if (out.size() == 0) { return; } bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; std::string module_name = fmt::format( "gather_axis_{}_{}", dtype_to_string(out.dtype()), dtype_to_string(idx.dtype())); cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { std::vector kernel_names; for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { for (int contiguous = 0; contiguous < 4; ++contiguous) { for (int large = 0; large <= 1; ++large) { kernel_names.push_back( fmt::format( "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", 
dtype_to_cuda_type(out.dtype()), dtype_to_cuda_type(idx.dtype()), ndim, contiguous & 1 ? true : false, contiguous & 2 ? true : false, large ? "int64_t" : "int32_t")); } } } return std::make_tuple( false, jit_source_gather_axis, std::move(kernel_names)); }); size_t idx_size_pre = 1; size_t idx_size_post = 1; for (int i = 0; i < axis_; ++i) { idx_size_pre *= idx.shape(i); } for (int i = axis_ + 1; i < idx.ndim(); ++i) { idx_size_post *= idx.shape(i); } size_t idx_size_axis = idx.shape(axis_); cu::KernelArgs args; args.append(src); args.append(idx); args.append(out); if (large) { args.append(idx_size_pre); args.append(idx_size_axis); args.append(idx_size_post); } else { args.append(idx_size_pre); args.append(idx_size_axis); args.append(idx_size_post); } args.append(remove_index(idx.shape(), axis_)); args.append(remove_index(src.strides(), axis_)); args.append(remove_index(idx.strides(), axis_)); args.append(axis_); args.append(src.shape(axis_)); args.append(src.strides(axis_)); args.append(idx.strides(axis_)); std::string kernel_name = fmt::format( "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), dtype_to_cuda_type(idx.dtype()), src.ndim() - 1, src.flags().row_contiguous, idx.flags().row_contiguous, large ? "int64_t" : "int32_t"); for (const auto& in : inputs) { encoder.set_input_array(in); } encoder.set_output_array(out); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(idx, large); encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, 0, args.args()); } void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("ScatterAxis::eval_gpu"); assert(inputs.size() > 2); const auto& src = inputs[0]; const auto& idx = inputs[1]; const auto& upd = inputs[2]; // Copy src into out. 
CopyType copy_type; if (src.data_size() == 1) { copy_type = CopyType::Scalar; } else if (src.flags().row_contiguous) { copy_type = CopyType::Vector; } else { copy_type = CopyType::General; } copy_gpu(src, out, copy_type); // Empty update. if (upd.size() == 0) { return; } bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign"; std::string module_name = fmt::format( "scatter_axis_{}_{}_{}", dtype_to_string(out.dtype()), dtype_to_string(idx.dtype()), op); auto& s = stream(); cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { std::vector kernel_names; for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { for (int contiguous = 0; contiguous < 4; ++contiguous) { for (int large = 0; large <= 1; ++large) { kernel_names.push_back( fmt::format( "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), dtype_to_cuda_type(idx.dtype()), op, ndim, contiguous & 1 ? true : false, contiguous & 2 ? true : false, large ? 
"int64_t" : "int32_t")); } } } return std::make_tuple( false, jit_source_scatter_axis, std::move(kernel_names)); }); size_t idx_size_pre = 1; size_t idx_size_post = 1; for (int i = 0; i < axis_; ++i) { idx_size_pre *= idx.shape(i); } for (int i = axis_ + 1; i < idx.ndim(); ++i) { idx_size_post *= idx.shape(i); } size_t idx_size_axis = idx.shape(axis_); cu::KernelArgs args; args.append(upd); args.append(idx); args.append(out); if (large) { args.append(idx_size_pre); args.append(idx_size_axis); args.append(idx_size_post); } else { args.append(idx_size_pre); args.append(idx_size_axis); args.append(idx_size_post); } args.append(remove_index(idx.shape(), axis_)); args.append(remove_index(upd.strides(), axis_)); args.append(remove_index(idx.strides(), axis_)); args.append(axis_); args.append(out.shape(axis_)); args.append(upd.strides(axis_)); args.append(idx.strides(axis_)); std::string kernel_name = fmt::format( "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), dtype_to_cuda_type(idx.dtype()), op, idx.ndim() - 1, upd.flags().row_contiguous, idx.flags().row_contiguous, large ? 
"int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { encoder.set_input_array(in); } encoder.set_output_array(out); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(idx, large); encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, 0, args.args()); } void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("MaskedScatter::eval_gpu"); assert(inputs.size() == 3); const array& dst = inputs[0]; const array& mask = inputs[1]; const array& src = inputs[2]; auto& s = stream(); auto& encoder = cu::get_command_encoder(s); const size_t total = mask.size(); out.set_data(cu::malloc_async(out.nbytes(), encoder)); if (total == 0) { return; } array mask_flat = flatten_in_eval(mask, 1, -1, s); if (mask_flat.data() != mask.data()) { encoder.add_temporary(mask_flat); } if (!mask_flat.flags().row_contiguous) { mask_flat = contiguous_copy_gpu(mask_flat, s); encoder.add_temporary(mask_flat); } array scatter_offsets(mask_flat.shape(), int32, nullptr, {}); scatter_offsets.set_data(cu::malloc_async(scatter_offsets.nbytes(), encoder)); encoder.add_temporary(scatter_offsets); scan_gpu_inplace( mask_flat, scatter_offsets, Scan::Sum, /* axis= */ 1, /* reverse= */ false, /* inclusive= */ false, s); const size_t batch_count = mask.shape(0); const size_t mask_batch_size = mask_flat.size() / batch_count; const size_t src_batch_size = src.size() / src.shape(0); bool large = total > INT32_MAX || src.size() > INT32_MAX; bool vectorized = src.flags().row_contiguous && dst.flags().row_contiguous; constexpr int kMaskedScatterVecSize = 16; constexpr int kMaskedScatterVecBlockDim = 256; std::string module_name = fmt::format("masked_scatter_{}", dtype_to_string(out.dtype())); cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { std::vector kernel_names; for (int src_contiguous = 0; src_contiguous <= 1; ++src_contiguous) { for (int dst_contiguous = 0; 
dst_contiguous <= 1; ++dst_contiguous) { for (int use_large = 0; use_large <= 1; ++use_large) { kernel_names.push_back( fmt::format( "mlx::core::cu::masked_scatter<{}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), src_contiguous ? "true" : "false", dst_contiguous ? "true" : "false", use_large ? "int64_t" : "int32_t")); } } } for (int use_large = 0; use_large <= 1; ++use_large) { kernel_names.push_back( fmt::format( "mlx::core::cu::masked_scatter_vec_contiguous<{}, {}, {}>", dtype_to_cuda_type(out.dtype()), use_large ? "int64_t" : "int32_t", kMaskedScatterVecSize)); } return std::make_tuple(false, jit_source_scatter, std::move(kernel_names)); }); cu::KernelArgs args; args.append(dst); args.append(mask_flat); args.append(scatter_offsets); args.append(src); args.append(out); if (large) { args.append(mask_flat.size()); args.append(src_batch_size); args.append(mask_batch_size); } else { args.append(mask_flat.size()); args.append(src_batch_size); args.append(mask_batch_size); } if (!vectorized) { args.append_ndim(dst.shape()); args.append_ndim(dst.strides()); args.append(dst.ndim()); args.append_ndim(src.shape()); args.append_ndim(src.strides()); args.append(src.ndim()); } encoder.set_input_array(dst); encoder.set_input_array(mask_flat); encoder.set_input_array(scatter_offsets); encoder.set_input_array(src); encoder.set_output_array(out); std::string kernel_name = vectorized ? fmt::format( "mlx::core::cu::masked_scatter_vec_contiguous<{}, {}, {}>", dtype_to_cuda_type(out.dtype()), large ? "int64_t" : "int32_t", kMaskedScatterVecSize) : fmt::format( "mlx::core::cu::masked_scatter<{}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), src.flags().row_contiguous ? "true" : "false", dst.flags().row_contiguous ? "true" : "false", large ? "int64_t" : "int32_t"); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = vectorized ? 
get_launch_args( mask_flat, large, kMaskedScatterVecSize, kMaskedScatterVecBlockDim) : get_launch_args(mask_flat, large); encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, 0, args.args()); } void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("SliceUpdate::eval_gpu"); assert(inputs.size() == 2); if (out.size() == 0) { return; } auto& in = inputs[0]; auto& upd = inputs[1]; if (upd.size() == 0) { out.copy_shared_buffer(in); return; } auto ctype = in.flags().contiguous && in.size() == in.data_size() ? CopyType::Vector : CopyType::General; copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); // Calculate out strides, initial offset and if copy needs to be made auto [data_offset, out_strides] = prepare_slice(out, start_indices_, strides_); // Do copy for None reduce type if (reduce_type_ == SliceUpdate::None) { copy_gpu_inplace( /* const array& src = */ upd, /* array& dst = */ out, /* const Shape& data_shape = */ upd.shape(), /* const Strides& i_strides = */ upd.strides(), /* const Strides& o_strides = */ out_strides, /* int64_t i_offset = */ 0, /* int64_t o_offset = */ data_offset, /* CopyType ctype = */ CopyType::GeneralGeneral, /* const Stream& s = */ stream()); return; } auto [shape, strides] = collapse_contiguous_dims(upd.shape(), {upd.strides(), out_strides}); int nwork = 1; if (shape.back() % 4 == 0) { nwork = 4; } else if (shape.back() % 2 == 0) { nwork = 2; } const char* op_name = g_slice_ops[reduce_type_]; auto [ds, rc, cc] = check_contiguity(shape, strides[1]); bool upd_contiguous = upd.flags().row_contiguous; bool upd_scalar = upd.data_size() == 1; bool out_contiguous = rc; bool large = upd.size() > INT32_MAX; std::string module_name = fmt::format("slice_update_{}_{}", op_name, dtype_to_string(out.dtype())); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { std::vector kernel_names; for (int out_c = 
0; out_c <= 1; ++out_c) { for (int upd_c = 0; upd_c <= 1; ++upd_c) { for (int upd_s = 0; upd_s <= 1; ++upd_s) { for (int large = 0; large <= 1; ++large) { for (int nwork = 1; nwork <= 16; nwork *= 2) { kernel_names.push_back( fmt::format( "mlx::core::cu::slice_update_op<{}, {}, mlx::core::cu::{}, {}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), large ? "int64_t" : "int32_t", op_name, out_c ? "true" : "false", upd_c ? "true" : "false", upd_s ? "true" : "false", nwork)); } } } } } return std::make_tuple( false, jit_source_slice_update, std::move(kernel_names)); }); cu::KernelArgs args; args.append(upd); args.append(out); args.append(upd.size()); args.append_ndim(shape); args.append_ndim(strides[0]); args.append(shape.size()); args.append_ndim(strides[1]); args.append(data_offset); encoder.set_input_array(upd); encoder.set_output_array(out); std::string kernel_name; kernel_name = fmt::format( "mlx::core::cu::slice_update_op<{}, {}, mlx::core::cu::{}, {}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), large ? "int64_t" : "int32_t", op_name, out_contiguous, upd_contiguous, upd_scalar, nwork); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(upd, large, nwork); encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, 0, args.args()); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/jit_module.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/device.h" #include "mlx/version.h" #include "cuda_jit_sources.h" #include #include #include #include #include namespace mlx::core::cu { namespace { #define CHECK_NVRTC_ERROR(cmd) check_nvrtc_error(#cmd, (cmd)) void check_nvrtc_error(const char* name, nvrtcResult err) { if (err != NVRTC_SUCCESS) { throw std::runtime_error( fmt::format("{} failed: {}", name, nvrtcGetErrorString(err))); } } // Return the default path to CUDA toolkit. 
const std::filesystem::path& default_cuda_toolkit_path() { #if defined(_WIN32) static auto cached_path = []() -> std::filesystem::path { std::filesystem::path root( LR"(C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA)"); for (auto& file : std::filesystem::directory_iterator(root)) { if (std::filesystem::exists(file.path() / "include" / "cuda.h")) { return file.path(); } } return {}; }(); #else static std::filesystem::path cached_path = "/usr/local/cuda"; #endif return cached_path; } // Return the --include-path args used for invoking NVRTC. const std::vector& include_path_args() { static std::vector cached_args = []() { std::vector args; // Add path to bundled CCCL headers. auto root_dir = current_binary_dir(); #if !defined(_WIN32) root_dir = root_dir.parent_path(); #endif auto path = root_dir / "include" / "cccl"; #if defined(MLX_CCCL_DIR) if (!std::filesystem::exists(path)) { path = MLX_CCCL_DIR; } #endif if (std::filesystem::exists(path)) { args.push_back(fmt::format("--include-path={}", path.string())); } // Add path to CUDA runtime headers, try local-installed python package // first and then system-installed headers. path = root_dir.parent_path() / "nvidia" / "cuda_runtime" / "include"; if (!std::filesystem::exists(path)) { const char* home = std::getenv("CUDA_HOME"); if (!home) { home = std::getenv("CUDA_PATH"); } path = home ? std::filesystem::path(home) : default_cuda_toolkit_path(); if (!path.empty()) { path = path / "include"; } if (path.empty() || !std::filesystem::exists(path)) { throw std::runtime_error( "Can not find locations of CUDA headers, please set environment " "variable CUDA_HOME or CUDA_PATH."); } } args.push_back(fmt::format("--include-path={}", path.string())); return args; }(); return cached_args; } // Get the cache directory for storing compiled results. 
// $MLX_PTX_CACHE_DIR takes precedence; otherwise <temp dir>/mlx/<version>/ptx
// is used. The directory is created on first use. An empty path, which turns
// the disk cache off, is returned when it cannot be determined or created.
const std::filesystem::path& ptx_cache_dir() {
  static std::filesystem::path cache = []() -> std::filesystem::path {
    std::filesystem::path cache;
    if (auto c = std::getenv("MLX_PTX_CACHE_DIR"); c) {
      cache = c;
    } else {
      // Non-throwing overload: when no usable temporary directory exists the
      // cache is simply disabled instead of aborting kernel compilation.
      std::error_code error;
      auto temp_dir = std::filesystem::temp_directory_path(error);
      if (error) {
        return std::filesystem::path();
      }
      cache = temp_dir / "mlx" / version() / "ptx";
    }
#if defined(_WIN32)
    // Add "\\?\" prefix to support long file path.
    const wchar_t* long_path_prefix = L"\\\\?\\";
    if (cache.is_relative()) {
      cache = std::filesystem::absolute(cache);
    }
    if (!cache.native().starts_with(long_path_prefix)) {
      cache = long_path_prefix + cache.native();
    }
#endif
    // create_directories() returns false *without* an error when the
    // directory already exists (e.g. another process created it first), so
    // only a reported error means the cache location is unusable.
    std::error_code error;
    std::filesystem::create_directories(cache, error);
    if (error) {
      return std::filesystem::path();
    }
    return cache;
  }();
  return cache;
}

// Return the path of the cached ".ptx" file for |module_name| in |cache_dir|.
// Names longer than max_file_name_length are split into a chain of nested
// directories so that no single path component gets too long (presumably to
// stay under the common 255-byte file name limit with room for an extension).
std::filesystem::path get_ptx_path(
    const std::filesystem::path& cache_dir,
    const std::string& module_name) {
  constexpr size_t max_file_name_length = 245;

  if (module_name.size() <= max_file_name_length) {
    return cache_dir / (module_name + ".ptx");
  }

  auto ptx_path = cache_dir;
  size_t offset = 0;
  while (module_name.size() - offset > max_file_name_length) {
    ptx_path /= module_name.substr(offset, max_file_name_length);
    offset += max_file_name_length;
  }
  ptx_path /= module_name.substr(offset) + ".ptx";

  return ptx_path;
}

// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
// Returns false when the cache is disabled (empty |cache_dir|) or when either
// the binary file or its companion ".txt" list of (name, mangled name) pairs
// cannot be fully read; in that case |ptx_kernels| is left untouched so the
// caller can rebuild from scratch.
bool read_cached_ptx(
    const std::filesystem::path& cache_dir,
    const std::string& module_name,
    std::string& ptx,
    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
  if (cache_dir.empty()) {
    return false;
  }

  auto ptx_path = get_ptx_path(cache_dir, module_name);
  std::error_code error;
  auto ptx_size = std::filesystem::file_size(ptx_path, error);
  if (error) {
    return false;
  }
  std::ifstream ptx_file(ptx_path, std::ios::binary);
  if (!ptx_file.good()) {
    return false;
  }
  ptx.resize(ptx_size);
  ptx_file.read(ptx.data(), static_cast<std::streamsize>(ptx_size));
  // A short read (e.g. the file is being rewritten) must not be treated as a
  // valid cache entry.
  if (!ptx_file ||
      ptx_file.gcount() != static_cast<std::streamsize>(ptx_size)) {
    ptx.clear();
    return false;
  }

  // The kernel-name list is always written alongside the binary; without it
  // the module could not resolve any kernel, so treat the entry as a miss.
  std::ifstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
  if (!txt_file.good()) {
    ptx.clear();
    return false;
  }
  std::string line;
  while (std::getline(txt_file, line)) {
    auto tab = line.find('\t');
    if (tab != std::string::npos) {
      ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
    }
  }
  return true;
}

// Write the |ptx| and |ptx_kernels| to |cache_dir| with |name|.
//
// The generated |source_code| is also saved next to them with a ".cu"
// extension. Writing the cache is best effort: when the destination directory
// cannot be created nothing is written and no exception is raised.
void write_cached_ptx(
    const std::filesystem::path& cache_dir,
    const std::string& module_name,
    const std::string& ptx,
    const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
    const std::string& source_code) {
  if (cache_dir.empty()) {
    return;
  }

  auto ptx_path = get_ptx_path(cache_dir, module_name);

  // Ensure that the directory exists (long module names are split into
  // nested directories).
  auto parent = ptx_path.parent_path();
  if (parent != cache_dir) {
    std::error_code error;
    std::filesystem::create_directories(parent, error);
    if (error) {
      return;
    }
  }

  // Write the compiled code and mangled names
  std::ofstream ptx_file(ptx_path, std::ios::binary);
  if (!ptx.empty()) {
    ptx_file.write(ptx.data(), static_cast<std::streamsize>(ptx.size()));
  }
  std::ofstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
  for (const auto& [name, mangled] : ptx_kernels) {
    txt_file << name << '\t' << mangled << '\n';
  }

  // Write the generated code
  std::ofstream source_file(ptx_path.replace_extension(".cu"));
  source_file << source_code;
}

// Return if |device|'s version is not newer than |major|.|minor| version.
inline bool version_lower_equal(Device& device, int major, int minor) {
  if (device.compute_capability_major() < major) {
    return true;
  } else if (device.compute_capability_major() == major) {
    return device.compute_capability_minor() <= minor;
  } else {
    return false;
  }
}

// Return whether NVRTC supports compiling to |device|'s SASS code.
// Each NVRTC release below 11.8 is mapped to the newest compute capability it
// can emit SASS for; compile() falls back to PTX ("compute_XY") otherwise.
bool compiler_supports_device_sass(Device& device) {
  int nvrtc_major, nvrtc_minor;
  CHECK_NVRTC_ERROR(nvrtcVersion(&nvrtc_major, &nvrtc_minor));
  if (nvrtc_major < 9) {
    return false;
  } else if (nvrtc_major == 9) {
    return version_lower_equal(device, 7, 2);
  } else if (nvrtc_major == 10) {
    return version_lower_equal(device, 7, 5);
  } else if (nvrtc_major == 11 && nvrtc_minor == 0) {
    return version_lower_equal(device, 8, 0);
  } else if (nvrtc_major == 11 && nvrtc_minor < 8) {
    return version_lower_equal(device, 8, 6);
  } else {
    return true;
  }
}

#define INCLUDE_PREFIX "mlx/backend/cuda/device/"

// Include names under which the embedded device headers are made visible to
// NVRTC. This array is index-aligned with g_headers below (both are passed
// together to nvrtcCreateProgram), so keep the two in the same order.
constexpr const char* g_include_names[] = {
    INCLUDE_PREFIX "atomic_ops.cuh",
    INCLUDE_PREFIX "binary_ops.cuh",
    INCLUDE_PREFIX "cast_op.cuh",
    INCLUDE_PREFIX "config.h",
    INCLUDE_PREFIX "complex.cuh",
    INCLUDE_PREFIX "fp16_math.cuh",
    INCLUDE_PREFIX "hadamard.cuh",
    INCLUDE_PREFIX "indexing.cuh",
    INCLUDE_PREFIX "scatter_ops.cuh",
    INCLUDE_PREFIX "unary_ops.cuh",
    INCLUDE_PREFIX "ternary_ops.cuh",
    INCLUDE_PREFIX "utils.cuh",
};

#undef INCLUDE_PREFIX

// Source text of the headers named in g_include_names, same order.
constexpr const char* g_headers[] = {
    jit_source_atomic_ops,
    jit_source_binary_ops,
    jit_source_cast_op,
    jit_source_config,
    jit_source_complex,
    jit_source_fp16_math,
    jit_source_hadamard,
    jit_source_indexing,
    jit_source_scatter_ops,
    jit_source_unary_ops,
    jit_source_ternary_ops,
    jit_source_utils,
};

// Compile |source| with NVRTC for |device|.
// Every entry of |kernel_names| is registered as a name expression so that its
// lowered (mangled) name can be appended to |ptx_kernels| as a (name, mangled)
// pair. |ptx| receives a CUBIN when compiler_supports_device_sass() is true
// and PTX otherwise. Throws std::runtime_error carrying the NVRTC program log
// when compilation fails.
void compile(
    Device& device,
    const std::string& module_name,
    const std::string& source,
    const std::vector& kernel_names,
    std::string& ptx,
    std::vector>& ptx_kernels) {
  // Create the program
  nvrtcProgram prog;
  CHECK_NVRTC_ERROR(nvrtcCreateProgram(
      &prog,
      source.c_str(),
      (module_name + ".cu").c_str(),
      std::size(g_headers),
      g_headers,
      g_include_names));
  // NOTE(review): CHECK_NVRTC_ERROR throws, so a failing nvrtcDestroyProgram
  // inside this deleter during unwinding from another exception would call
  // std::terminate.
  std::unique_ptr prog_freer(
      &prog,
      [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
  for (const auto& name : kernel_names) {
    CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
  }

  // Compile program.
  std::vector args;
  bool use_sass = compiler_supports_device_sass(device);
  auto cc = device.compute_capability_major();
  // For compute capability 9.x and newer the "a" suffix is appended, giving
  // e.g. --gpu-architecture=sm_90a.
  std::string arch_tag = (cc >= 9) ? "a" : "";
  std::string compute = fmt::format(
      "--gpu-architecture={}_{}{}{}",
      use_sass ? "sm" : "compute",
      cc,
      device.compute_capability_minor(),
      arch_tag);
  args.push_back(compute.c_str());
  for (const auto& include : include_path_args()) {
    args.push_back(include.c_str());
  }
  nvrtcResult compile_result =
      nvrtcCompileProgram(prog, args.size(), args.data());
  if (compile_result != NVRTC_SUCCESS) {
    size_t log_size;
    CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
    std::vector log(log_size + 1, 0);
    CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
    throw std::runtime_error(
        fmt::format("Failed to compile kernel: {}.", log.data()));
  }

  // Get mangled names of kernel names.
  for (const auto& name : kernel_names) {
    const char* mangled;
    CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
    ptx_kernels.emplace_back(name, mangled);
  }

  // Get ptx data (a CUBIN when compiling straight to SASS).
  size_t ptx_size;
  if (use_sass) {
    CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
  } else {
    CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
  }
  ptx.resize(ptx_size);
  if (use_sass) {
    CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
  } else {
    CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
  }
}

// Load |ptx| (PTX or CUBIN) into |module_| and resolve every mangled name in
// |ptx_kernels|, storing the CUfunction in |kernels| under its unmangled name.
// Each entry starts as (function, false, 0): not yet configured and with no
// cached block dimension; get_kernel_and_dims() fills these in on first use.
void load_module(
    const std::string& module_name,
    const std::string& ptx,
    const std::vector>& ptx_kernels,
    CUmodule& module_,
    std::unordered_map>& kernels) {
  // Load module.
  // The buffer is zero-initialized and the driver is told it is one byte
  // smaller, so the error log always stays NUL-terminated.
  char jit_log[4089] = {};
  CUjit_option options[] = {
      CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
  void* values[] = {jit_log, reinterpret_cast(std::size(jit_log) - 1)};
  CUresult jit_result = cuModuleLoadDataEx(
      &module_, ptx.data(), std::size(options), options, values);
  if (jit_result != CUDA_SUCCESS) {
    throw std::runtime_error(
        fmt::format(
            "Failed to load compiled {} kernel: {}.", module_name, jit_log));
  }

  // Load kernels.
  for (const auto& [name, mangled] : ptx_kernels) {
    CUfunction kernel;
    CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
    kernels[name] = std::make_tuple(kernel, false, 0);
  }
}

} // namespace

// Build the module: reuse the on-disk cache entry for |module_name| when one
// can be read, otherwise run |builder| and either take its output as-is
// (precompiled) or compile it with NVRTC.
JitModule::JitModule(
    Device& device,
    const std::string& module_name,
    const KernelBuilder& builder,
    bool use_disk_cache) {
  // Will hold the actual device executable source code and kernel names
  std::string ptx;
  std::vector> ptx_kernels;

  // Try to load them from the file cache
  // NOTE(review): the disk cache is read even when |use_disk_cache| is false;
  // only the write below is gated on it — confirm this is intended.
  if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
    auto [precompiled, source_code, kernel_names] = builder();

    // Get the PTX or cubin
    if (precompiled) {
      // Precompiled code is used verbatim; names are assumed to already be
      // the symbol names in the binary.
      ptx = std::move(source_code);
      for (auto& name : kernel_names) {
        ptx_kernels.emplace_back(name, name);
      }
    } else {
      compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
    }

    // If requested save them in the file cache for the next launch
    if (use_disk_cache) {
      write_cached_ptx(
          ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
    }
  }

  // Load the module
  load_module(module_name, ptx, ptx_kernels, module_, kernels_);
}

JitModule::~JitModule() {
  CHECK_CUDA_ERROR(cuModuleUnload(module_));
}

// Look up |kernel_name| and return it together with its cached block
// dimension (computed by max_occupancy_block_dim on first use). Throws
// std::runtime_error if the module has no kernel of that name.
std::pair JitModule::get_kernel_and_dims(
    const std::string& kernel_name,
    std::function configure_kernel) {
  auto it = kernels_.find(kernel_name);
  if (it == kernels_.end()) {
    throw std::runtime_error(
        fmt::format("There is no kernel named {}.", kernel_name));
  }

  // If it is the first time we run this kernel then configure it. Do it only
  // once!
  auto kernel = std::get<0>(it->second);
  if (!std::get<1>(it->second)) {
    if (configure_kernel) {
      configure_kernel(kernel);
    }
    std::get<1>(it->second) = true;
    std::get<2>(it->second) = max_occupancy_block_dim(kernel);
  }

  return {kernel, std::get<2>(it->second)};
}

CUfunction JitModule::get_kernel(
    const std::string& kernel_name,
    std::function configure_kernel) {
  return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
}

// Process-wide map of JIT modules, keyed by module name.
std::unordered_map& get_jit_module_cache() {
  static std::unordered_map map;
  return map;
}

// Return the cached JitModule named |name|, constructing it with |builder| on
// first request. NOTE(review): the key is the name only (not the device), and
// no locking is visible here — confirm callers never reuse a name across
// devices or call this concurrently.
JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const KernelBuilder& builder,
    bool cache) {
  auto& map = get_jit_module_cache();
  auto it = map.find(name);
  if (it == map.end()) {
    it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
  }
  return it->second;
}

} // namespace mlx::core::cu

================================================
FILE: mlx/backend/cuda/jit_module.h
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"

// NOTE(review): the header names on the following #include lines are missing
// in this copy — TODO restore them.
#include
#include
#include
#include
#include
#include

namespace mlx::core::cu {

class Device;

using KernelBuilderResult = std::tuple<
    /* precompiled */ bool,
    /* source code */ std::string,
    /* kernel names */ std::vector>;
using KernelBuilder = std::function;

// Collects kernel arguments for launch APIs that take an array of pointers to
// the argument values (see args()). Values appended by value are copied into
// |storage_|; because it is a std::deque, appending at the end never
// invalidates references to earlier entries, so the pointers collected in
// |args_| stay valid while more arguments are added.
struct KernelArgs {
  void** args() {
    return args_.data();
  }

  void append(const array& a) {
    append(reinterpret_cast(gpu_ptr(a)));
  }

  template
  void append(T val) {
    storage_.emplace_back(val);
    append_ptr(&storage_.back());
  }

  template
  void append(SmallVector vec) {
    storage_.emplace_back(std::move(vec));
    append_ptr(std::get>(storage_.back()).data());
  }

  template
  void append(const std::vector& vec) {
    append(SmallVector(vec.begin(), vec.end()));
  }

  // Make sure the arg is copied to an array with size of NDIM.
  // Throws std::runtime_error if |vec| has more than NDIM entries; shorter
  // vectors are resized up to NDIM.
  template
  void append_ndim(SmallVector vec) {
    if (vec.size() > NDIM) {
      throw std::runtime_error(
          fmt::format("ndim can not be larger than {}.", NDIM));
    }
    vec.resize(NDIM);
    append(std::move(vec));
  }

  void append_ptr(const void* v) {
    args_.push_back(const_cast(v));
  }

 private:
  std::vector args_;

  // The cuGraphAddKernelNode API requires passing pointers to arguments so
  // store temporary values until the node is created.
  using Arg = std::variant<
      std::monostate,
      CUdeviceptr,
      bool,
      int32_t,
      uint32_t,
      int64_t,
      float,
      SmallVector,
      SmallVector,
      SmallVector>;
  std::deque storage_;
};

// Owns a CUmodule built from JIT-compiled, disk-cached or precompiled device
// code plus the kernels resolved from it. Non-copyable; the module is
// unloaded in the destructor.
class JitModule {
 public:
  JitModule(
      Device& device,
      const std::string& module_name,
      const KernelBuilder& builder,
      bool cache);
  ~JitModule();

  JitModule(const JitModule&) = delete;
  JitModule& operator=(const JitModule&) = delete;
  CUfunction get_kernel(
      const std::string& kernel_name,
      std::function configure_kernel = nullptr);
  std::pair get_kernel_and_dims(
      const std::string& kernel_name,
      std::function configure_kernel = nullptr);

 private:
  CUmodule module_{nullptr};
  // Per kernel name: (function, configured flag, cached block dimension).
  std::unordered_map> kernels_;
};

std::unordered_map& get_jit_module_cache();

JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const KernelBuilder& builder,
    bool use_disk_cache = true);

} // namespace mlx::core::cu

================================================
FILE: mlx/backend/cuda/kernel_utils.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/kernel_utils.cuh"

namespace mlx::core {

// The helpers below wrap the backend-agnostic *_common versions from
// backend/common/utils.h and convert their results into CUDA dim3 values.

dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) {
  Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2);
  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) {
  Dims dims = get_2d_grid_dims_common(shape, strides);
  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

dim3 get_2d_grid_dims(
    const Shape& shape,
    const Strides& strides,
    size_t divisor) {
  Dims dims = get_2d_grid_dims_common(shape, strides, divisor);
  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
  auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2);
  auto [gx, gy, gz] = grid;
  auto [bx, by, bz] = block;
  return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
}

// Return (num_blocks, block_dim) for a launch in which each thread handles
// |work_per_thread| of |size| elements. The block size is capped at
// |max_block_dim|. When |large| is set the grid is laid out in 2D from
// |shape| / |strides|; otherwise a 1D grid is used.
std::tuple<dim3, uint32_t> get_launch_args(
    size_t size,
    const Shape& shape,
    const Strides& strides,
    bool large,
    int work_per_thread /* = 1 */,
    uint32_t max_block_dim /* = 1024 */) {
  size_t nthreads = cuda::ceil_div(size, work_per_thread);
  uint32_t block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
  // For an empty array (nthreads == 0) or max_block_dim == 0, block_dim would
  // otherwise be 0 and the ceil_div calls below would divide by zero.
  if (block_dim == 0) {
    block_dim = 1;
  }
  dim3 num_blocks;
  if (large) {
    num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
    num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
  } else {
    num_blocks.x = cuda::ceil_div(nthreads, block_dim);
  }
  return std::make_tuple(num_blocks, block_dim);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/kernel_utils.cuh
================================================
// Copyright © 2025 Apple Inc.

// This file includes host-only utilities for writing CUDA kernels, the
// difference from backend/cuda/device/utils.cuh is that the latter file only
// includes device-only code.
#pragma once

// NOTE(review): the header names on the empty #include lines in this file are
// missing in this copy — TODO restore them.
#include

#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/utils.cuh"

#include
#include
#include
#include
#include

namespace mlx::core {

// Invoke |f| with a compile-time integral constant matching |n| (1, 2 or 3).
// NOTE(review): any other value of |n| silently does nothing.
template
void dispatch_1_2_3(int n, F&& f) {
  switch (n) {
    case 1:
      f(std::integral_constant{});
      break;
    case 2:
      f(std::integral_constant{});
      break;
    case 3:
      f(std::integral_constant{});
      break;
  }
}

// Invoke |f| with std::true_type or std::false_type according to |v|, so the
// flag can be used as a template argument inside |f|.
template
void dispatch_bool(bool v, F&& f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

// Invoke |f| with a compile-time block size selected from |threads|. The
// buckets are bounded by WARP_SIZE times 1, 2, 4, 8 and 16; any larger count
// falls into the final bucket.
template
void dispatch_block_dim(int threads, F&& f) {
  if (threads <= WARP_SIZE) {
    f(std::integral_constant{});
  } else if (threads <= WARP_SIZE * 2) {
    f(std::integral_constant{});
  } else if (threads <= WARP_SIZE * 4) {
    f(std::integral_constant{});
  } else if (threads <= WARP_SIZE * 8) {
    f(std::integral_constant{});
  } else if (threads <= WARP_SIZE * 16) {
    f(std::integral_constant{});
  } else {
    f(std::integral_constant{});
  }
}

// Maps CPU types to CUDA types.
template
struct CTypeToCudaType {
  using type = T;
};

template <>
struct CTypeToCudaType {
  using type = __half;
};

template <>
struct CTypeToCudaType {
  using type = __nv_bfloat16;
};

template <>
struct CTypeToCudaType {
  using type = cu::complex64_t;
};

template
using cuda_type_t = typename CTypeToCudaType::type;

// Type traits for detecting floating numbers.
template
inline constexpr bool is_floating_v =
    cuda::std::is_same_v || cuda::std::is_same_v ||
    cuda::std::is_same_v || cuda::std::is_same_v ||
    cuda::std::is_same_v || cuda::std::is_same_v;

// Type traits for detecting complex numbers.
template
inline constexpr bool is_complex_v =
    cuda::std::is_same_v || cuda::std::is_same_v;

// Type traits for detecting complex or real floating point numbers.
template
inline constexpr bool is_inexact_v = is_floating_v || is_complex_v;

// Utility to copy data from vector to array in host.
// Throws std::runtime_error when |vec| has more than NDIM entries.
// NOTE(review): |result| is not value-initialized, so for scalar element types
// the entries past vec.size() are left indeterminate — callers presumably only
// read the first ndim entries; confirm.
template
inline cuda::std::array const_param(const SmallVector& vec) {
  if (vec.size() > NDIM) {
    throw std::runtime_error(
        fmt::format("ndim can not be larger than {}.", NDIM));
  }
  cuda::std::array result;
  std::copy_n(vec.begin(), vec.size(), result.begin());
  return result;
}

// Compute the grid and block dimensions, check backend/common/utils.h for docs.
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
dim3 get_2d_grid_dims(
    const Shape& shape,
    const Strides& strides,
    size_t divisor);
std::pair get_grid_and_block(int dim0, int dim1, int dim2);

// Get the num_blocks and block_dims assuming each thread handles
// |work_per_thread| elements of |arr|.
std::tuple get_launch_args(
    size_t size,
    const Shape& shape,
    const Strides& strides,
    bool large,
    int work_per_thread = 1,
    uint32_t max_block_dim = 1024);

// Convenience overload that takes size, shape and strides from |arr|.
inline std::tuple get_launch_args(
    const array& arr,
    bool large,
    int work_per_thread = 1,
    uint32_t max_block_dim = 1024) {
  return get_launch_args(
      arr.size(),
      arr.shape(),
      arr.strides(),
      large,
      work_per_thread,
      max_block_dim);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/layer_norm.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"

#include
#include
#include

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Component-wise sum of two float3 values.
inline __device__ float3 plus_f3(const float3& a, const float3& b) {
  return {a.x + b.x, a.y + b.y, a.z + b.z};
}

// Similar to cub::BlockReduce, but result is broadcasted to every thread.
template struct BlockBroadcastReduce { static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE); static_assert(BLOCK_DIM % WARP_SIZE == 0); using TempStorage = T[BLOCK_DIM / WARP_SIZE]; cg::thread_block& block; TempStorage& temp; template __device__ T Reduce(const T& input, const Op& op, const T& init_value) { auto warp = cg::tiled_partition(block); T x = cg::reduce(warp, input, op); if (warp.thread_rank() == 0) { temp[warp.meta_group_rank()] = x; } block.sync(); x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()] : init_value; return cg::reduce(warp, x, op); } __device__ T Sum(const T& input) { return Reduce(input, cg::plus{}, T{}); } }; template __global__ void layer_norm( const T* x, const T* w, const T* b, T* out, float eps, int32_t axis_size, int64_t w_stride, int64_t b_stride) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); using BlockReduceT = BlockBroadcastReduce; __shared__ typename BlockReduceT::TempStorage temp; x += grid.block_rank() * axis_size; out += grid.block_rank() * axis_size; // Sum. float sum = 0; for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); auto xn = load_vector(x, index, axis_size, T(0)); #pragma unroll for (int i = 0; i < N_READS; ++i) { sum += static_cast(xn[i]); } } sum = BlockReduceT{block, temp}.Sum(sum); // Mean. float mean = sum / axis_size; // Normalizer. 
float normalizer = 0; for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); if ((index + 1) * N_READS <= axis_size) { auto xn = load_vector(x, index); #pragma unroll for (int i = 0; i < N_READS; ++i) { float t = static_cast(xn[i]) - mean; normalizer += t * t; } } else { for (int i = index * N_READS; i < axis_size; ++i) { float t = static_cast(x[i]) - mean; normalizer += t * t; } } } normalizer = BlockReduceT{block, temp}.Sum(normalizer); normalizer = rsqrt(normalizer / axis_size + eps); // Outputs. for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); auto xn = load_vector(x, index, axis_size, T(0)); auto wn = load_vector(w, index, axis_size, w_stride, T(0)); auto bn = load_vector(b, index, axis_size, b_stride, T(0)); #pragma unroll for (int i = 0; i < N_READS; ++i) { float norm = (static_cast(xn[i]) - mean) * normalizer; xn[i] = wn[i] * static_cast(norm) + bn[i]; } store_vector(out, index, xn, axis_size); } } template __global__ void layer_norm_vjp( const T* x, const T* w, const T* g, T* gx, T* gw, float eps, int32_t axis_size, int64_t w_stride) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); using BlockReduceF = BlockBroadcastReduce; using BlockReduceF3 = BlockBroadcastReduce; __shared__ union { typename BlockReduceF::TempStorage f; typename BlockReduceF3::TempStorage f3; } temp; x += grid.block_rank() * axis_size; g += grid.block_rank() * axis_size; gx += grid.block_rank() * axis_size; gw += grid.block_rank() * axis_size; // Sum. float sum = 0; for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); auto xn = load_vector(x, index, axis_size, T(0)); #pragma unroll for (int i = 0; i < N_READS; ++i) { sum += static_cast(xn[i]); } } sum = BlockReduceF{block, temp.f}.Sum(sum); // Mean. float mean = sum / axis_size; // Normalizer. 
float3 factors = {}; for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); auto gn = load_vector(g, index, axis_size, T(0)); auto wn = load_vector(w, index, axis_size, w_stride, T(0)); if ((index + 1) * N_READS <= axis_size) { auto xn = load_vector(x, index); #pragma unroll for (int i = 0; i < N_READS; ++i) { float t = static_cast(xn[i]) - mean; float wi = wn[i]; float gi = gn[i]; float wg = wi * gi; factors = plus_f3(factors, {wg, wg * t, t * t}); } } else { for (int i = index * N_READS; i < axis_size; ++i) { float t = static_cast(x[i]) - mean; float wi = wn[i]; float gi = gn[i]; float wg = wi * gi; factors = plus_f3(factors, {wg, wg * t, t * t}); } } } factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {}); float meanwg = factors.x / axis_size; float meanwgxc = factors.y / axis_size; float normalizer2 = 1 / (factors.z / axis_size + eps); float normalizer = sqrt(normalizer2); // Outputs. for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); auto xn = load_vector(x, index, axis_size, T(0)); auto gn = load_vector(g, index, axis_size, T(0)); auto wn = load_vector(w, index, axis_size, w_stride, T(0)); for (int i = 0; i < N_READS; i++) { float xi = (static_cast(xn[i]) - mean) * normalizer; float wi = wn[i]; float gi = gn[i]; xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2; if constexpr (HAS_W) { wn[i] = gi * xi; } } store_vector(gx, index, xn, axis_size); if constexpr (HAS_W) { store_vector(gw, index, wn, axis_size); } } } } // namespace cu namespace fast { bool LayerNorm::use_fallback(Stream s) { return s.device == Device::cpu; } // TODO: There are duplicate code with backend/metal/normalization.cpp void LayerNorm::eval_gpu( const std::vector& inputs, std::vector& outputs) { nvtx3::scoped_range r("LayerNorm::eval_gpu"); auto& s = stream(); auto& out = outputs[0]; auto& encoder = 
cu::get_command_encoder(s); // Make sure that the last dimension is contiguous. auto set_output = [&s, &out, &encoder](const array& x) { bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { if (x.is_donatable()) { out.copy_shared_buffer(x); } else { out.set_data( cu::malloc_async(x.data_size() * x.itemsize(), encoder), x.data_size(), x.strides(), x.flags()); } return x; } else { array x_copy = contiguous_copy_gpu(x, s); out.copy_shared_buffer(x_copy); return x_copy; } }; const array x = set_output(inputs[0]); const array& w = inputs[1]; const array& b = inputs[2]; int32_t axis_size = x.shape().back(); int32_t n_rows = x.data_size() / axis_size; int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(b); encoder.set_output_array(out); dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { using DataType = cuda_type_t; constexpr int N_READS = 16 / sizeof(DataType); dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::layer_norm; encoder.add_kernel_node( kernel, n_rows, block_dim(), gpu_ptr(x), gpu_ptr(w), gpu_ptr(b), gpu_ptr(out), eps_, axis_size, w_stride, b_stride); }); }); } void LayerNormVJP::eval_gpu( const std::vector& inputs, std::vector& outputs) { nvtx3::scoped_range r("LayerNormVJP::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. 
auto check_input = [&s](const array& x, bool& copied) { if (x.flags().row_contiguous) { copied = false; return x; } copied = true; return contiguous_copy_gpu(x, s); }; bool donate_x = inputs[0].is_donatable(); bool donate_g = inputs[3].is_donatable(); bool copied; auto x = check_input(inputs[0], copied); donate_x |= copied; const array& w = inputs[1]; const array& b = inputs[2]; bool g_copied; auto g = check_input(inputs[3], g_copied); donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; array& gb = outputs[2]; // Check whether we had a weight. bool has_w = w.ndim() != 0; // Allocate space for the outputs. bool g_in_gx = false; if (donate_x) { gx.copy_shared_buffer(x); } else if (donate_g) { gx.copy_shared_buffer(g); g_in_gx = true; } else { gx.set_data(cu::malloc_async(gx.nbytes(), encoder)); } if (g_copied && !g_in_gx) { encoder.add_temporary(g); } int32_t axis_size = x.shape().back(); int32_t n_rows = x.data_size() / axis_size; int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; // Allocate a temporary to store the gradients for w and allocate the output // gradient accumulators. array gw_temp = (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; bool g_in_gw = false; if (has_w) { if (!g_in_gx && donate_g) { g_in_gw = true; gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder)); encoder.add_temporary(gw_temp); } } // The gradient for b in case we had a b. 
bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size); if (has_gb) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan); } // Insert dependency if `g` was donated if ((g_in_gx || g_in_gw) && has_gb) { encoder.set_input_array(gb); } encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(g); encoder.set_output_array(gx); encoder.set_output_array(gw_temp); dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { dispatch_bool(has_w, [&](auto has_w_constant) { using DataType = cuda_type_t; constexpr int N_READS = 16 / sizeof(DataType); dispatch_block_dim( cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::layer_norm_vjp< DataType, has_w_constant.value, block_dim(), N_READS>; encoder.add_kernel_node( kernel, n_rows, block_dim(), gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), gpu_ptr(gx), gpu_ptr(gw_temp), eps_, axis_size, w_stride); }); }); }); if (has_w) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan); } } } // namespace fast } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/load.cpp ================================================ // Copyright © 2023 Apple Inc. 
#include #include #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/utils.h" #include "mlx/primitives.h" namespace { template void swap_endianness(uint8_t* data_bytes, size_t N) { struct Elem { uint8_t bytes[scalar_size]; }; Elem* data = reinterpret_cast(data_bytes); for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < (scalar_size / 2); j++) { std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]); } } } } // namespace namespace mlx::core { void Load::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(stream()); auto size = out.size(); auto nbytes = size * out.itemsize(); out.set_data(cu::malloc_async(nbytes, encoder)); auto out_ptr = malloc(nbytes); reader_->read(static_cast(out_ptr), nbytes, offset_); if (swap_endianness_) { switch (out.itemsize()) { case 2: swap_endianness<2>(reinterpret_cast(out_ptr), size); break; case 4: swap_endianness<4>(reinterpret_cast(out_ptr), size); break; case 8: swap_endianness<8>(reinterpret_cast(out_ptr), size); break; } } CHECK_CUDA_ERROR(cudaMemcpyAsync( gpu_ptr(out), out_ptr, nbytes, cudaMemcpyDefault, encoder.stream())); CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr)); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/logsumexp.cu ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" #include #include #include #include #include namespace mlx::core { namespace cu { namespace cg = cooperative_groups; template inline __device__ T softmax_exp(T x) { // Softmax doesn't need high precision exponential cause x is gonna be in // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). 
return __expf(x); } template __global__ void logsumexp(const T* in, T* out, int axis_size) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); in += grid.block_rank() * axis_size; cg::greater max_op; cg::plus plus_op; // Thread reduce. AccT prevmax; AccT maxval = Limits::finite_min(); AccT normalizer = 0; for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { auto index = r * BLOCK_DIM + block.thread_rank(); auto vals = load_vector(in, index, axis_size, Limits::min()); prevmax = maxval; #pragma unroll for (int i = 0; i < N_READS; ++i) { maxval = max_op(maxval, static_cast(vals[i])); } // Online normalizer calculation for softmax: // https://github.com/NVIDIA/online-softmax normalizer = normalizer * softmax_exp(prevmax - maxval); for (int i = 0; i < N_READS; i++) { normalizer = normalizer + softmax_exp(static_cast(vals[i]) - maxval); } } // First warp reduce. prevmax = maxval; maxval = cg::reduce(warp, maxval, max_op); normalizer = normalizer * softmax_exp(prevmax - maxval); normalizer = cg::reduce(warp, normalizer, plus_op); __shared__ AccT local_max[WARP_SIZE]; __shared__ AccT local_normalizer[WARP_SIZE]; // Write to shared memory and do second warp reduce. prevmax = maxval; if (warp.thread_rank() == 0) { local_max[warp.meta_group_rank()] = maxval; } block.sync(); maxval = warp.thread_rank() < warp.meta_group_size() ? local_max[warp.thread_rank()] : Limits::finite_min(); maxval = cg::reduce(warp, maxval, max_op); normalizer = normalizer * softmax_exp(prevmax - maxval); if (warp.thread_rank() == 0) { local_normalizer[warp.meta_group_rank()] = normalizer; } block.sync(); normalizer = warp.thread_rank() < warp.meta_group_size() ? local_normalizer[warp.thread_rank()] : AccT{}; normalizer = cg::reduce(warp, normalizer, plus_op); // Write output. if (block.thread_rank() == 0) { out[grid.block_rank()] = isinf(maxval) ? 
maxval : log(normalizer) + maxval; } } } // namespace cu void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("LogSumExp::eval_gpu"); assert(inputs.size() == 1); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); // Make sure that the last dimension is contiguous. auto ensure_contiguous = [&s, &encoder](const array& x) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { return x; } else { array x_copy = contiguous_copy_gpu(x, s); encoder.add_temporary(x_copy); return x_copy; } }; auto in = ensure_contiguous(inputs[0]); if (in.flags().row_contiguous) { out.set_data(cu::malloc_async(out.nbytes(), encoder)); } else { auto n = in.shape(-1); auto flags = in.flags(); auto strides = in.strides(); for (auto& s : strides) { s /= n; } bool col_contig = strides[0] == 1; for (int i = 1; col_contig && i < strides.size(); ++i) { col_contig &= (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]); } flags.col_contiguous = col_contig; out.set_data( cu::malloc_async(in.nbytes() / n, encoder), in.data_size() / n, std::move(strides), flags); } int axis_size = in.shape().back(); int n_rows = in.data_size() / axis_size; encoder.set_input_array(in); encoder.set_output_array(out); dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { using DataType = cuda_type_t; constexpr int N_READS = 16 / sizeof(DataType); dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::logsumexp; encoder.add_kernel_node( kernel, n_rows, block_dim(), gpu_ptr(in), gpu_ptr(out), axis_size); }); }); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/lru_cache.h ================================================ // Copyright © 2025 Apple Inc. 
#pragma once #include "mlx/utils.h" #include #include #include #include #include namespace mlx::core { template < typename K, typename V, template typename M = std::unordered_map> class LRUCache { public: using value_type = std::pair; using list_type = std::list; using iterator = typename list_type::iterator; using const_iterator = typename list_type::const_iterator; using map_type = M; explicit LRUCache(size_t capacity) : capacity_(capacity) { if (capacity == 0) { throw std::runtime_error("LRUCache requires capacity > 0."); } } // Initialize with capacity read from |env_name|. LRUCache(const char* env_name, int default_capacity) : LRUCache(env::get_var(env_name, default_capacity)) { if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) { env_name_ = env_name; } } size_t size() const { return map_.size(); } size_t capacity() const { return capacity_; } bool empty() const { return vlist_.empty(); } void resize(size_t new_capacity) { capacity_ = new_capacity; trim(); } iterator begin() { return vlist_.begin(); } const_iterator begin() const { return vlist_.begin(); } iterator end() { return vlist_.end(); } const_iterator end() const { return vlist_.end(); } void clear() { map_.clear(); vlist_.clear(); } iterator find(const K& key) { auto it = map_.find(key); if (it == map_.end()) return end(); vlist_.splice(vlist_.begin(), vlist_, it->second); return it->second; } template std::pair emplace(const K& key, U&& value) { auto it = map_.find(key); if (it != map_.end()) { vlist_.splice(vlist_.begin(), vlist_, it->second); return {it->second, false}; } if (env_name_ && ++cache_misses_ > 2 * capacity_) { throw std::runtime_error( fmt::format( "Cache thrashing is happening, please set the environment variable " "{} to a larger value than {} to fix degraded performance.", env_name_, capacity_)); } vlist_.emplace_front(key, std::forward(value)); map_[key] = vlist_.begin(); trim(); return {vlist_.begin(), true}; } iterator erase(iterator pos) { map_.erase(pos->first); return 
vlist_.erase(pos); } V& operator[](const K& key) { auto it = find(key); if (it == end()) { it = emplace(key, V{}).first; } return it->second; } private: void trim() { while (map_.size() > capacity_) { auto last = std::prev(vlist_.end()); map_.erase(last->first); vlist_.pop_back(); } } const char* env_name_{nullptr}; size_t cache_misses_{0}; list_type vlist_; map_type map_; size_t capacity_; }; // Turn a POD struct into a container key by doing bytes compare. // // IMPORTANT: Do not use aggregate init on the pod field (key.pod = {...}). // It creates a stack temporary whose padding bytes are uninitialized, and // trivial copy-assignment copies the entire struct including padding — // breaking the memcmp-based comparison. Set fields individually instead. // // Usage: // BytesKey key; // key.pod.field1 = value1; // key.pod.field2 = value2; template struct BytesKey { T pod; static_assert(std::is_standard_layout_v, "T is not POD"); BytesKey() { // Make sure the paddings between members are filled with 0. memset(&pod, 0, sizeof(T)); } BytesKey(const BytesKey& other) { memcpy(&pod, &other.pod, sizeof(T)); } BytesKey(BytesKey&& other) { memcpy(&pod, &other.pod, sizeof(T)); } bool operator==(const BytesKey& other) const { auto* ptr1 = reinterpret_cast(&pod); auto* ptr2 = reinterpret_cast(&other.pod); return memcmp(ptr1, ptr2, sizeof(T)) == 0; } }; // Compute hash according to the bytes value of T. 
template struct BytesHash { static_assert(std::is_standard_layout_v, "T is not POD"); size_t operator()(const T& pod) const { auto* ptr = reinterpret_cast(&pod); uint32_t value = 0x811C9DC5; for (int i = 0; i < sizeof(T); ++i) { value ^= ptr[i]; value *= 0x01000193; } return value; } }; template using BytesKeyHashMap = std::unordered_map>; template using LRUBytesKeyCache = LRUCache, V, BytesKeyHashMap>; } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/matmul.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/common/matmul.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/gemms/gemv.h" #include "mlx/backend/cuda/gemms/grouped_gemm.h" #include "mlx/backend/gpu/copy.h" #include "mlx/primitives.h" #include #include namespace mlx::core { namespace { std::tuple check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (sty == 1 && stx == arr.shape(-1)) { return std::make_tuple(false, stx, arr); } else if (stx == 1 && sty == arr.shape(-2)) { return std::make_tuple(true, sty, arr); } else { array arr_copy = contiguous_copy_gpu(arr, s); enc.add_temporary(arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy); } } std::tuple ensure_batch_contiguous(const array& x, cu::CommandEncoder& encoder, Stream s) { if (x.flags().row_contiguous) { return std::make_tuple(false, x.strides(-2), x); } bool rc = true; for (int i = 0; i < x.ndim() - 3; i++) { rc &= (x.strides(i + 1) * x.shape(i)) == x.strides(i); } if (rc) { return check_transpose(encoder, s, x); } array x_copy = contiguous_copy_gpu(x, s); encoder.add_temporary(x_copy); return std::make_tuple(false, x_copy.strides(-2), x_copy); } array ensure_row_contiguous( const array& x, cu::CommandEncoder& encoder, Stream s) { if (!x.flags().row_contiguous) 
{ array x_copy = contiguous_copy_gpu(x, s); encoder.add_temporary(x_copy); return x_copy; } else { return x; } } void gemm_and_bias( cu::CommandEncoder& encoder, int M, int N, int K, bool a_transposed, int64_t lda, bool b_transposed, int64_t ldb, array& out, const array& a, const array& b, const std::optional& bias = std::nullopt, float alpha = 1.0f) { // Check and collapse batch dimensions auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b); auto batch_count = out.size() / (M * N); // Collapse batches into M if needed if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && b_batch_strides.back() == 0) { M *= batch_shape.back(); batch_count = 1; a_batch_strides = {0}; b_batch_strides = {0}; batch_shape = {1}; } // Use gemmv when possible if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) { cu::gemv( a, b, out, M, N, K, batch_count, batch_shape, a_batch_strides, b_batch_strides, encoder); return; } // Invoke cublasLt CublasGemm gemm( encoder.device(), a.dtype(), a_transposed, M, K, lda, b_transposed, K, N, ldb, batch_shape.back(), a_batch_strides.back(), b_batch_strides.back()); if (bias) { if (a.dtype() == complex64) { throw std::runtime_error( "[gemm_and_bias] complex64 bias epilogue isn’t supported in cublasLtMatmul."); } gemm.set_bias(encoder, *bias); } gemm.run( encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha); } void gather_mm_rhs( const array& a_, const array& b_, const array& indices_, array& out, cu::CommandEncoder& encoder, Stream s) { if (a_.size() / a_.shape(-2) / a_.shape(-1) != indices_.size()) { throw std::runtime_error("[gather_mm] Broadcasting lhs is not supported."); } int group_count = b_.size() / b_.shape(-1) / b_.shape(-2); if (group_count > 1024) { throw std::runtime_error( "[gather_mm] Group count can not be larger than 1024."); } auto [a_transposed, lda, a] = ensure_batch_contiguous(a_, encoder, 
s); auto [b_transposed, ldb, b] = ensure_batch_contiguous(b_, encoder, s); auto indices = ensure_row_contiguous(indices_, encoder, s); cutlass_grouped_gemm_unaligned( a_transposed, lda, b_transposed, ldb, group_count, a, b, indices, out, encoder); } } // namespace void Matmul::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Matmul::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); assert(inputs.size() == 2); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; // Return 0s if either input is empty. if (a_pre.size() == 0 || b_pre.size() == 0) { array zero(0, a_pre.dtype()); encoder.add_temporary(zero); fill_gpu(zero, out, s); return; } out.set_data(cu::malloc_async(out.nbytes(), encoder)); int M = a_pre.shape(-2); int N = b_pre.shape(-1); int K = a_pre.shape(-1); // Keep a vector with copies to be cleared in the completed buffer to release // the arrays auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); gemm_and_bias( encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b); } void AddMM::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("AddMM::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); assert(inputs.size() == 3); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; auto c = inputs[2]; ///////////////////////////////////////////////////////////////////////////// // Init checks and prep int M = a_pre.shape(-2); int N = b_pre.shape(-1); int K = a_pre.shape(-1); // Keep a vector with copies to be cleared in the completed buffer to release // the arrays auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); ///////////////////////////////////////////////////////////////////////////// // Dispatch to GEMM with epilogue or AddMM if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 && c.data_size() 
== out.shape(-1)) { out.set_data(cu::malloc_async(out.nbytes(), encoder)); gemm_and_bias( encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b, c, alpha_); return; } int64_t ldc; { auto stx = c.strides()[c.ndim() - 2]; auto sty = c.strides()[c.ndim() - 1]; if (sty == 1 && stx == c.shape(-1)) { ldc = stx; out.set_data(cu::malloc_async(out.nbytes(), encoder)); } else if (sty == 1 && stx == 0) { ldc = 0; out.set_data(cu::malloc_async(out.nbytes(), encoder)); } else { // Copy C into out and set C to out ldc = c.shape(-1); copy_gpu(c, out, CopyType::General, s); c = out; } } ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] = collapse_batches(a, b, c); auto batch_count = out.size() / (M * N); // Collapse batches into M if needed if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && c_batch_strides.back() == M * c.strides()[c.ndim() - 2] && b_batch_strides.back() == 0) { M *= batch_shape.back(); batch_count = 1; a_batch_strides = {0}; b_batch_strides = {0}; c_batch_strides = {0}; batch_shape = {1}; } ///////////////////////////////////////////////////////////////////////////// // Invoke cublasLt with AddMM settings CublasGemm gemm( cu::device(s.device), a.dtype(), a_transposed, M, K, lda, b_transposed, K, N, ldb, ldc, batch_shape.back(), a_batch_strides.back(), b_batch_strides.back(), c_batch_strides.back()); gemm.run( encoder, out, a, b, c, batch_shape, a_batch_strides, b_batch_strides, c_batch_strides, alpha_, beta_); } void GatherMM::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("GatherMM::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); assert(inputs.size() == 4); auto& a = inputs[0]; auto& b = inputs[1]; auto& lhs_indices = inputs[2]; auto& rhs_indices = inputs[3]; // Return 0s if 
either input is empty. if (a.size() == 0 || b.size() == 0) { array zero(0, a.dtype()); encoder.add_temporary(zero); fill_gpu(zero, out, s); return; } out.set_data(cu::malloc_async(out.nbytes(), encoder)); // Extract shapes from inputs. int M = a.shape(-2); int N = b.shape(-1); int K = a.shape(-1); // We are walking a in order and b is also in order so we can batch up the // matmuls and reuse reading a and b. if (M == 1 && right_sorted_ == true) { gather_mm_rhs(a, b, rhs_indices, out, encoder, s); return; } auto [transposed_a, lda, a_] = check_transpose(encoder, s, a); auto [transposed_b, ldb, b_] = check_transpose(encoder, s, b); auto use_gemv = cu::can_use_gemv(M, N, K, transposed_a, transposed_b); if (M == 1 && use_gemv) { gather_mv(b_, a_, rhs_indices, lhs_indices, out, N, K, encoder); return; } if (N == 1 && use_gemv) { gather_mv(a_, b_, lhs_indices, rhs_indices, out, M, K, encoder); return; } throw std::runtime_error("NYI"); } void SegmentedMM::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("SegmentedMM::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); assert(inputs.size() == 3); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; auto& segments_pre = inputs[2]; // Return zeros if output is empty or either input is empty. 
if (out.size() == 0 || a_pre.size() == 0 || b_pre.size() == 0) { array zero(0, a_pre.dtype()); encoder.add_temporary(zero); fill_gpu(zero, out, s); return; } out.set_data(cu::malloc_async(out.nbytes(), encoder)); int M = a_pre.shape(-2); int N = b_pre.shape(-1); int num_segments = segments_pre.size() / 2; auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); auto segments = [&] { if (segments_pre.flags().row_contiguous) { return segments_pre; } array copy = contiguous_copy_gpu(segments_pre, s); encoder.add_temporary(copy); return copy; }(); cutlass_segmented_mm( a_transposed, lda, b_transposed, ldb, num_segments, M, N, a, b, segments, out, encoder); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/no_cuda.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/cuda.h" #include "mlx/fast.h" namespace mlx::core { namespace cu { bool is_available() { return false; } } // namespace cu namespace fast { CustomKernelFunction cuda_kernel( const std::string&, const std::vector&, const std::vector&, const std::string&, const std::string&, bool, int) { throw std::runtime_error("[cuda_kernel] No CUDA back-end."); } std::vector precompiled_cuda_kernel( const std::string&, const std::string&, const std::vector&, const std::vector&, const std::vector&, const std::vector&, std::tuple, std::tuple, int shared_memory, std::optional init_value, bool ensure_row_contiguous, StreamOrDevice) { throw std::runtime_error("[cuda_kernel] No CUDA back-end."); } } // namespace fast } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/primitives.cpp ================================================ // Copyright © 2025 Apple Inc. 
#include "mlx/distributed/primitives.h"
#include
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"

namespace mlx::core {

// Each macro below expands to an eval_gpu definition for primitive `Prim`.
// The definition only throws "<Prim> has no CUDA implementation.".

// Single-output signature: eval_gpu(inputs, out).
#define NO_GPU(Prim)                                                       \
  void Prim::eval_gpu(const std::vector& /* inputs */, array& /* out */) { \
    throw std::runtime_error(#Prim " has no CUDA implementation.");        \
  }

// Multi-output signature: eval_gpu(inputs, outputs).
#define NO_GPU_MULTI(Prim)                                             \
  void Prim::eval_gpu(                                                 \
      const std::vector& /* inputs */, std::vector& /* outputs */) {   \
    throw std::runtime_error(#Prim " has no CUDA implementation.");    \
  }

// Multi-output stub plus use_fallback() returning true.
#define NO_GPU_USE_FALLBACK(Prim)           \
  bool Prim::use_fallback(Stream /* s */) { \
    return true;                            \
  }                                         \
  NO_GPU_MULTI(Prim)

NO_GPU(BlockMaskedMM)
NO_GPU(GatherQMM)
NO_GPU(Inverse)
NO_GPU(Cholesky)

NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU_MULTI(SVD)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)

namespace distributed {
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/affine_quantize.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"

#include
#include
#include <stdexcept>
#include <string>

namespace mlx::core {
namespace cu {

namespace cg = cooperative_groups;

// Affine quantization kernel.
// Each thread reads values_per_reduce = group_size / WARP_SIZE consecutive
// inputs, so one warp covers one group. The group's min/max are reduced
// across the warp to derive one (scale, bias) pair, written by the group's
// first thread. The codes are then packed into bytes of `out`; lanes are
// combined with warp shuffles when a pack spans several threads.
// NOTE(review): the template parameter list is not visible in this copy of
// the file. T / group_size / bits are named from their uses below.
template __global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
  auto block_size = cg::this_thread_block().dim_threads();
  auto block_idx = cg::this_thread_block().group_index();
  auto idx_in_block = cg::this_thread_block().thread_index();

  auto tidx = block_idx.x * block_size.x + idx_in_block.x;
  auto tidy = block_idx.y * block_size.y + idx_in_block.y;

  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
  constexpr float eps = 1e-7;
  constexpr int simd_size = WARP_SIZE;
  constexpr float n_bins = (1 << bits) - 1;
  constexpr int pack_factor = get_pack_factor(bits, 8);
  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
  constexpr int values_per_reduce = group_size / simd_size;
  constexpr int writes_per_reduce = pack_factor / values_per_reduce;
  constexpr int writes_per_pack =
      writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;

  size_t offset = tidx + grid_dim_x * size_t(tidy);
  size_t in_index = offset * values_per_reduce;
  if (in_index >= size) {
    return;
  }
  size_t out_index = power_of_2_bits
      ? offset * writes_per_pack
      : offset * bytes_per_pack / writes_per_reduce;

  float w_thread[values_per_reduce];
  float w_min = Limits::max();
  float w_max = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    float val = w[in_index + i];
    w_thread[i] = val;
    w_min = min(w_min, val);
    w_max = max(w_max, val);
  }

  cg::greater max_op;
  cg::less min_op;
  auto warp = cg::tiled_partition(cg::this_thread_block());

  w_min = cg::reduce(warp, w_min, min_op);
  w_max = cg::reduce(warp, w_max, max_op);

  // Anchor the scale on the edge with the larger magnitude and snap it so
  // that edge is exactly representable (q0 quantization steps away).
  float scale = max((w_max - w_min) / n_bins, eps);
  bool side = abs(w_min) > abs(w_max);
  scale = side ? scale : -scale;
  float edge = side ? w_min : w_max;
  float q0 = round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  float bias = at_zero ? 0 : edge;

  // Write out the scales and biases
  size_t gindex = in_index / group_size;
  if (in_index % group_size == 0) {
    scales[gindex] = static_cast(scale);
    biases[gindex] = static_cast(bias);
  }

  using OutType = std::conditional_t;
  OutType output = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output |= val << (bits * (i % pack_factor));
    }

    if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
      out[out_index + i / pack_factor] = output;
      output = 0;
    } else {
      // Gather the codes of the next lanes that share this pack.
#pragma clang loop unroll(full)
      for (int j = 1; j < writes_per_reduce; j++) {
        uint8_t sval = warp.shfl_down(val, j);
        output |= static_cast(sval) << (bits * (j * values_per_reduce + i));
      }
    }
  }
  if constexpr (bits == 3 || bits == 6) {
    // 3-byte packs.
    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
    }
  } else if constexpr (bits == 5) {
    // 5-byte packs.
    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
      out[out_index + 3] = (output & 0xff000000) >> 24;
      out[out_index + 4] = (output & 0xff00000000) >> 32;
    }
  } else {
    if constexpr (writes_per_reduce > 0) {
      if (out_index % writes_per_reduce == 0) {
        out[out_index / writes_per_reduce] = output;
      }
    }
  }
}

// Affine dequantization kernel.
// Each thread decodes one pack:
//   - 8 values from 3 bytes (bits == 3);
//   - 8 values from 5 bytes (bits == 5);
//   - 4 values from 3 bytes (bits == 6);
//   - pack_factor values from one byte (other bit widths).
// Output: out[i] = code * scale + bias, with the group's scale and bias.
template __global__ void affine_dequantize(
    const uint8_t* w,
    const T* scales,
    const T* biases,
    T* out,
    size_t size) {
  auto block_size = cg::this_thread_block().dim_threads();
  auto block_idx = cg::this_thread_block().group_index();
  auto idx_in_block = cg::this_thread_block().thread_index();

  auto tidx = block_idx.x * block_size.x + idx_in_block.x;
  auto tidy = block_idx.y * block_size.y + idx_in_block.y;

  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;

  constexpr int pack_factor = get_pack_factor(bits, 8);
  constexpr int bytes_per_pack = get_bytes_per_pack(bits);

  size_t offset = tidx + grid_dim_x * size_t(tidy);
  size_t oindex = offset * pack_factor;

  if (oindex >= size) {
    return;
  }

  size_t gindex = oindex / group_size;
  T scale = scales[gindex];
  T bias = biases[gindex];
  out += oindex;

  if constexpr (bits == 3) {
    w += offset * bytes_per_pack;
    out[0] = static_cast(w[0] & 0x7) * scale + bias;
    out[1] = static_cast((w[0] & 0x38) >> 3) * scale + bias;
    out[2] = (static_cast((w[0] & 0xc0) >> 6) + static_cast((w[1] & 0x1) << 2)) * scale + bias;
    out[3] = static_cast((w[1] & 0xe) >> 1) * scale + bias;
    out[4] = static_cast((w[1] & 0x70) >> 4) * scale + bias;
    out[5] = (static_cast((w[1] & 0x80) >> 7) + static_cast((w[2] & 0x3) << 1)) * scale + bias;
    out[6] = static_cast((w[2] & 0x1c) >> 2) * scale + bias;
    out[7] = static_cast((w[2] & 0xe0) >> 5) * scale + bias;
  } else if constexpr (bits == 5) {
    w += offset * bytes_per_pack;
    out[0] = static_cast(w[0] & 0x1f) * scale + bias;
    out[1] = (static_cast((w[0] & 0xe0) >> 5) + static_cast((w[1] & 0x3) << 3)) * scale + bias;
    out[2] = static_cast((w[1] & 0x7c) >> 2) * scale + bias;
    out[3] = (static_cast((w[1] & 0x80) >> 7) + static_cast((w[2] & 0xf) << 1)) * scale + bias;
    out[4] = (static_cast((w[2] & 0xf0) >> 4) + static_cast((w[3] & 0x1) << 4)) * scale + bias;
    out[5] = static_cast((w[3] & 0x3e) >> 1) * scale + bias;
    out[6] = (static_cast((w[3] & 0xc0) >> 6) + static_cast((w[4] & 0x7) << 2)) * scale + bias;
    out[7] = static_cast((w[4] & 0xf8) >> 3) * scale + bias;
  } else if constexpr (bits == 6) {
    w += offset * bytes_per_pack;
    out[0] = static_cast(w[0] & 0x3f) * scale + bias;
    out[1] = (static_cast((w[0] >> 6) & 0x03) + static_cast((w[1] & 0x0f) << 2)) * scale + bias;
    out[2] = (static_cast((w[1] >> 4) & 0x0f) + static_cast((w[2] & 0x03) << 4)) * scale + bias;
    out[3] = static_cast((w[2] >> 2) & 0x3f) * scale + bias;
  } else {
    uint32_t val = w[offset];
#pragma clang loop unroll(full)
    for (int i = 0; i < pack_factor; i++) {
      uint8_t d;
      if (bits == 2) {
        d = (val >> (bits * i)) & 0x03;
      } else if (bits == 4) {
        d = (val >> (bits * i)) & 0x0f;
      } else if (bits == 8) {
        d = val;
      }
      out[i] = scale * static_cast(d) + bias;
    }
  }
}

} // namespace cu

// Calls `f` with a std::integral_constant for the runtime group size, so the
// kernels receive it as a compile-time constant.
// Throws for sizes without an instantiation. Previously no case matched, no
// kernel was enqueued, and the outputs were silently left unwritten.
template void dispatch_groups(int group_size, F&& f) {
  switch (group_size) {
    case 32:
      f(std::integral_constant{});
      break;
    case 64:
      f(std::integral_constant{});
      break;
    case 128:
      f(std::integral_constant{});
      break;
    default:
      throw std::runtime_error(
          "Unsupported group size for affine quantization: " +
          std::to_string(group_size));
  }
}

// Same as dispatch_groups, but for the number of bits.
// Throws for unsupported values instead of silently doing nothing.
template void dispatch_bits(int bits, F&& f) {
  switch (bits) {
    case 2:
      f(std::integral_constant{});
      break;
    case 3:
      f(std::integral_constant{});
      break;
    case 4:
      f(std::integral_constant{});
      break;
    case 5:
      f(std::integral_constant{});
      break;
    case 6:
      f(std::integral_constant{});
      break;
    case 8:
      f(std::integral_constant{});
      break;
    default:
      throw std::runtime_error(
          "Unsupported number of bits for affine quantization: " +
          std::to_string(bits));
  }
}

// Enqueues cu::affine_quantize over `w`. One thread handles
// group_size_ / WARP_SIZE values, so the launch covers w.size() / per_thread
// threads.
void affine_quantize(
    const array& w,
    array& wq,
    array& scales,
    array& biases,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s) {
  // Calculate the number of elements per thread
  int per_thread = group_size_ / WARP_SIZE;
  size_t size = w.size() / per_thread;

  // Calculate the thread grid that we need to launch
  bool large = size > UINT_MAX;
  auto grid_shape = w.shape();
  grid_shape.back() /= per_thread;

  enc.set_input_array(w);
  enc.set_output_array(wq);
  enc.set_output_array(scales);
  enc.set_output_array(biases);
  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
    dispatch_groups(group_size_, [&](auto group_size) {
      dispatch_bits(bits_, [&](auto bits) {
        using T = cuda_type_t;
        auto kernel = cu::affine_quantize;
        auto [num_blocks, block_dims] =
            get_launch_args(size, grid_shape, w.strides(), large);
        enc.add_kernel_node(
            kernel,
            num_blocks,
            block_dims,
            gpu_ptr(w),
            gpu_ptr(wq),
            gpu_ptr(scales),
            gpu_ptr(biases),
            w.size());
      });
    });
  });
}

// Enqueues cu::affine_dequantize writing `w`. One thread decodes one pack of
// packs_per_int output values.
void affine_dequantize(
    const array& wq,
    const array& scales,
    const array& biases,
    array& w,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s) {
  // Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
  // one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
  int packs_per_int;
  switch (bits_) {
    case 3:
    case 5:
      packs_per_int = 8;
      break;
    case 6:
      packs_per_int = 4;
      break;
    default:
      packs_per_int = 8 / bits_;
  }

  size_t size = w.size() / packs_per_int;
  bool large = size > UINT_MAX;
  // The grid must describe exactly `size` threads (one per pack). This mirrors
  // affine_quantize, which divides the last axis by the values per thread.
  // The previous code multiplied the last axis by 4, so the large (2-D)
  // launch spawned 4 * packs_per_int times more threads than needed.
  // Divisibility follows from the last axis being a multiple of group_size.
  auto grid_shape = w.shape();
  grid_shape.back() /= packs_per_int;

  enc.set_input_array(wq);
  enc.set_input_array(scales);
  enc.set_input_array(biases);
  enc.set_output_array(w);
  dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
    dispatch_groups(group_size_, [&](auto group_size) {
      dispatch_bits(bits_, [&](auto bits) {
        using T = cuda_type_t;
        auto kernel = cu::affine_dequantize;
        auto [num_blocks, block_dims] =
            get_launch_args(size, grid_shape, w.strides(), large);
        enc.add_kernel_node(
            kernel,
            num_blocks,
            block_dims,
            gpu_ptr(wq),
            gpu_ptr(scales),
            gpu_ptr(biases),
            gpu_ptr(w),
            w.size());
      });
    });
  });
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/convert_fp8.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"
#include "mlx/fast_primitives.h"

namespace mlx::core {

// Element-wise conversion to or from FP8, selected by to_fp8_.
// NOTE(review): the two branches differ only in the unary_op_gpu template
// argument, which is not visible in this copy of the file.
void fast::ConvertFP8::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  nvtx3::scoped_range r("ConvertFP8::eval_gpu");
  auto& out = outputs[0];
  auto& s = out.primitive().stream();
  if (to_fp8_) {
    unary_op_gpu(inputs, out, name(), s);
  } else {
    unary_op_gpu(inputs, out, name(), s);
  }
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/cublas_qqmm.cpp
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/quantized/cublas_qqmm.h"

#include

#include "mlx/backend/cuda/cublas_utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {

// cuBLASLt settings for one quantization mode:
//   data_type   - element type of the quantized operands;
//   scale_dtype - type of the block scales;
//   scale_mode  - block-scaling mode.
struct QuantModeConfig {
  cudaDataType_t data_type;
  cudaDataType_t scale_dtype;
  cublasLtMatmulMatrixScale_t scale_mode;
};

// Maps a quantization mode name to its cuBLASLt configuration.
// Throws std::runtime_error for modes other than "mxfp8" and "nvfp4".
QuantModeConfig get_quant_mode_config(const std::string& mode) {
  if (mode == "mxfp8") {
    return {
        CUDA_R_8F_E4M3,
        CUDA_R_8F_UE8M0,
        CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0};
  } else if (mode == "nvfp4") {
    return {
        CUDA_R_4F_E2M1,
        CUDA_R_8F_UE4M3,
        CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3};
  }
  throw std::runtime_error(
      fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
}

} // namespace

CublasQQMM::CublasQQMM(
    cu::Device& device,
    bool a_transposed,
    uint64_t a_rows,
    uint64_t a_cols,
    int64_t lda,
    bool b_transposed,
    uint64_t b_rows,
    uint64_t b_cols,
    int64_t ldb,
    int32_t batch_count,
    int64_t a_batch_stride,
    int64_t b_batch_stride,
    Dtype out_dtype,
    const std::string& qmode) {
  auto config = get_quant_mode_config(qmode);

  // The compute type must be CUBLAS_COMPUTE_32F.
  // The scale type must be CUDA_R_32F.
  cudaDataType_t scale_type = CUDA_R_32F;
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
  cudaDataType_t output_type =
      cublas_utils::dtype_to_cublas_type(out_dtype, "CublasQQMM");

  init_base(
      device,
      scale_type,
      gemm_compute_type,
      config.data_type,
      output_type,
      a_transposed,
      a_rows,
      a_cols,
      lda,
      b_transposed,
      b_rows,
      b_cols,
      ldb,
      batch_count,
      a_batch_stride,
      b_batch_stride);

  a_scale_mode_ = config.scale_mode;
  b_scale_mode_ = config.scale_mode;
  // The A/B scale attributes are deliberately cross-wired (cuBLASLt B gets
  // MLX's a, and vice versa), matching set_scales_ptrs. Presumably the
  // operands are swapped for cuBLASLt's column-major convention.
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
      &a_scale_mode_,
      sizeof(a_scale_mode_)));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
      &b_scale_mode_,
      sizeof(b_scale_mode_)));
}

// Same as the constructor above, plus a layout for an addend matrix C with
// leading dimension `ldc` and batch stride `c_batch_stride`.
CublasQQMM::CublasQQMM(
    cu::Device& device,
    bool a_transposed,
    uint64_t a_rows,
    uint64_t a_cols,
    int64_t lda,
    bool b_transposed,
    uint64_t b_rows,
    uint64_t b_cols,
    int64_t ldb,
    int64_t ldc,
    int32_t batch_count,
    int64_t a_batch_stride,
    int64_t b_batch_stride,
    int64_t c_batch_stride,
    Dtype out_dtype,
    const std::string& qmode)
    : CublasQQMM(
          device,
          a_transposed,
          a_rows,
          a_cols,
          lda,
          b_transposed,
          b_rows,
          b_cols,
          ldb,
          batch_count,
          a_batch_stride,
          b_batch_stride,
          out_dtype,
          qmode) {
  auto type = cublas_utils::dtype_to_cublas_type(
      out_dtype, "CublasQQMM"); // must match the output type
  c_desc_ = cublas_utils::create_matrix_layout(
      type,
      b_transposed ? b_rows : b_cols,
      a_transposed ? a_cols : a_rows,
      false,
      ldc,
      batch_count,
      c_batch_stride);
}

// Runs the matmul with alpha/beta read from device arrays (device pointer mode).
void CublasQQMM::run(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const array& a_scale,
    const array& b_scale,
    const array& alpha,
    const array& beta) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(a_scale);
  encoder.set_input_array(b_scale);
  encoder.set_input_array(alpha);
  encoder.set_input_array(beta);
  encoder.set_output_array(out);

  execute(
      encoder,
      gpu_ptr(out),
      gpu_ptr(a),
      gpu_ptr(b),
      gpu_ptr(a_scale),
      gpu_ptr(b_scale),
      nullptr,
      gpu_ptr(alpha),
      gpu_ptr(beta));
}

// Runs the matmul with the host-scalar defaults alpha = 1, beta = 0.
void CublasQQMM::run(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const array& a_scale,
    const array& b_scale) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(a_scale);
  encoder.set_input_array(b_scale);
  encoder.set_output_array(out);

  execute(
      encoder,
      gpu_ptr(out),
      gpu_ptr(a),
      gpu_ptr(b),
      gpu_ptr(a_scale),
      gpu_ptr(b_scale),
      nullptr);
}

// Installs the block-scale pointers. They are cross-wired like the scale
// modes in the constructor.
void CublasQQMM::set_scales_ptrs(
    cu::CommandEncoder& encoder,
    const void* a_scale,
    const void* b_scale) {
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
      &b_scale,
      sizeof(b_scale)));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
      &a_scale,
      sizeof(a_scale)));
}

void CublasQQMM::execute(
    cu::CommandEncoder& encoder,
    void* out,
    const void* a,
    const void* b,
    const void* a_scale,
    const void* b_scale,
    const void* c,
    const void* alpha,
    const void* beta) {
  set_scales_ptrs(encoder, a_scale, b_scale);
  // alpha and beta must be device pointers for nvfp4, but cuBLASLt defaults
  // to host pointers, so switch the pointer mode explicitly.
  // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
  cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_POINTER_MODE,
      &pointer_mode,
      sizeof(pointer_mode)));
  execute_matmul(encoder, out, a, b, c, alpha, beta);
}

void CublasQQMM::execute(
    cu::CommandEncoder& encoder,
    void* out,
    const void* a,
    const void* b,
    const void* a_scale,
    const void* b_scale,
    const void* c,
    const float alpha /* = 1 */,
    const float beta /* = 0 */) {
  set_scales_ptrs(encoder, a_scale, b_scale);
  // alpha and beta are host scalars here, so use host pointer mode.
  cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_POINTER_MODE,
      &pointer_mode,
      sizeof(pointer_mode)));
  // Addresses of the by-value parameters. They stay valid for the duration
  // of the execute_matmul call. (These lines had been corrupted into the
  // HTML entities "α"/"β", which does not compile.)
  const void* alpha_ptr = &alpha;
  const void* beta_ptr = &beta;
  execute_matmul(encoder, out, a, b, c, alpha_ptr, beta_ptr);
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/cublas_qqmm.h
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/backend/cuda/cublas_utils.h"
#include "mlx/backend/cuda/device.h"

#include

namespace mlx::core {

// Block-scaled quantized matmul on cuBLASLt. Both operands are quantized in
// the same mode ("mxfp8" or "nvfp4"); any other mode throws at construction.
class CublasQQMM : public CublasMatmulBase {
 public:
  CublasQQMM(
      cu::Device& device,
      bool a_transposed,
      uint64_t a_rows,
      uint64_t a_cols,
      int64_t lda,
      bool b_transposed,
      uint64_t b_rows,
      uint64_t b_cols,
      int64_t ldb,
      int32_t batch_count,
      int64_t a_batch_stride,
      int64_t b_batch_stride,
      Dtype out_dtype,
      const std::string& quantization_mode);

  // Variant that also describes an addend matrix C (ldc, c_batch_stride).
  CublasQQMM(
      cu::Device& device,
      bool a_transposed,
      uint64_t a_rows,
      uint64_t a_cols,
      int64_t lda,
      bool b_transposed,
      uint64_t b_rows,
      uint64_t b_cols,
      int64_t ldb,
      int64_t ldc,
      int32_t batch_count,
      int64_t a_batch_stride,
      int64_t b_batch_stride,
      int64_t c_batch_stride,
      Dtype out_dtype,
      const std::string& quantization_mode);

  // alpha/beta are device arrays (device pointer mode).
  void run(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const array& a_scale,
      const array& b_scale,
      const array& alpha,
      const array& beta);

  // Uses host-scalar alpha = 1, beta = 0.
  void run(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const array& a_scale,
      const array& b_scale);

 private:
  void set_scales_ptrs(
      cu::CommandEncoder& encoder,
      const void* a_scale,
      const void* b_scale);

  void execute(
      cu::CommandEncoder& encoder,
      void* out,
      const void* a,
      const void* b,
      const void* a_scale,
      const void* b_scale,
      const void* c,
      const void* alpha,
      const void* beta);

  void execute(
      cu::CommandEncoder& encoder,
      void* out,
      const void* a,
      const void* b,
      const void* a_scale,
      const void* b_scale,
      const void* c,
      const float alpha = 1.0f,
      const float beta = 0.0f);

  cublasLtMatmulMatrixScale_t a_scale_mode_;
  cublasLtMatmulMatrixScale_t b_scale_mode_;
  // NOTE(review): not assigned anywhere in cublas_qqmm.cpp.
  cublasLtMatmulMatrixScale_t c_scale_mode_;
  cublasLtMatmulMatrixScale_t out_scale_mode_;
};

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/fp_quantize.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/vector_types.cuh"
#include "mlx/dtype_utils.h"

#include
#include
#include
#include

constexpr float F8E4M3_MAX = 448.0f;
constexpr float F4E2M1_MAX = 6.0f;

namespace mlx::core {
namespace cu {

// Decodes one stored code to float.
// Codes are interpreted as cutlass float_e4m3_t when bits == 8, and as
// float_e2m1_t otherwise.
template struct Dequantize {
  __device__ float operator()(uint8_t x) {
    if constexpr (bits == 8) {
      return float(*(cutlass::float_e4m3_t*)(&x));
    } else {
      return float(*(cutlass::float_e2m1_t*)(&x));
    }
  }
};

// Lane-wise out = max(|x1|, |x2|) for a 2-wide vector.
// Two-lane half types use __habs2/__hmax2; float2 uses fabsf/fmaxf per
// component.
template
__device__ __forceinline__ void absmax_x2(T& out, const T& x1, const T& x2) {
  if constexpr (
      (std::is_same::value) ||
      (std::is_same::value)) {
    T a = x1;
    T b = x2;
    out = __hmax2(__habs2(a), __habs2(b));
  } else if constexpr (std::is_same::value) {
    float2 a = x1;
    float2 b = x2;
    out.x = fmaxf(fabsf(a.x), fabsf(b.x));
    out.y = fmaxf(fabsf(a.y), fabsf(b.y));
  }
}

namespace cg = cooperative_groups;

template __global__ void fp_quantize_dequantize(
    T* w,
    T* out,
size_t size, float* global_scale = nullptr) { const bool use_global_scale = global_scale != nullptr; const float scale_enc = use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; const float inv_scale_enc = use_global_scale ? 1.0f / scale_enc : 1.0f; using Tx2 = Vector2_t; uint32_t rbits = 0; // reserved bits for future use auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); auto tidx = block_idx.x * block_size.x + idx_in_block.x; auto tidy = block_idx.y * block_size.y + idx_in_block.y; auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; size_t thread_idx = tidx + grid_dim_x * size_t(tidy); size_t base_idx = thread_idx * group_size; if (base_idx >= size) { return; } auto w_tile = load_vector(w, thread_idx); float scale_dec_b = 0.0f; Tx2 amax_2x = Tx2{0.0f, 0.0f}; #pragma unroll for (int i = 0; i < group_size; i += 2) { auto pair = Tx2{w_tile[i], w_tile[i + 1]}; absmax_x2(amax_2x, amax_2x, pair); } scale_dec_b = static_cast( max(fabsf(static_cast(amax_2x.x)), fabsf(static_cast(amax_2x.y)))); scale_dec_b /= bits == 4 ? 
F4E2M1_MAX : F8E4M3_MAX; scale_dec_b *= scale_enc; // Convert to mx scale or nv scale using ScaleType = std::conditional_t< use_mx_scale, cutlass::float_ue8m0_t, cutlass::float_e4m3_t>; auto s = ScaleType(scale_dec_b); float scale_enc_b = scale_enc / float(s); float scale_dec = float(s) * inv_scale_enc; AlignedVector w_hat; #pragma unroll for (int i = 0; i < group_size / 8; i++) { auto& w = *reinterpret_cast*>(&w_tile[i * 8]); cutlass::NumericArrayConverter fp32_t; auto scaled = fp32_t(w) * scale_enc_b; cutlass::Array dq; if constexpr (bits == 8) { cutlass::NumericArrayConverter fp8_fp32; auto quant = fp8_fp32(scaled); cutlass::NumericArrayConverter fp32_fp8; dq = fp32_fp8(quant); } else { cutlass::NumericArrayConverter fp4_fp32; auto quant = fp4_fp32(scaled); cutlass::NumericArrayConverter fp32_fp4; dq = fp32_fp4(quant); } cutlass::NumericArrayConverter t_fp32; *reinterpret_cast*>(&w_hat[i * 8]) = t_fp32(dq * scale_dec); } store_vector(out, thread_idx, w_hat); } template __global__ void fp_quantize_rowwise( T* w, uint8_t* out, uint8_t* scales, size_t size, float* global_scale = nullptr) { // NVFP4 conversion: // Global encode scale: (448 × 6) / *global_scale // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8 // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b const bool use_global_scale = global_scale != nullptr; const float scale_enc = use_global_scale ? 
(F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; using Tx2 = Vector2_t; using Tx4 = Vector4_t; uint32_t rbits = 0; // reserved bits for future use auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); auto tidx = block_idx.x * block_size.x + idx_in_block.x; auto tidy = block_idx.y * block_size.y + idx_in_block.y; auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; size_t thread_idx = tidx + grid_dim_x * size_t(tidy); size_t base_idx = thread_idx * group_size; if (base_idx >= size) { return; } auto w_tile = load_vector(w, thread_idx); float scale_dec_b = 0.0f; Tx2 amax_2x = Tx2{0.0f, 0.0f}; #pragma unroll for (int i = 0; i < group_size; i += 2) { auto pair = Tx2{w_tile[i], w_tile[i + 1]}; absmax_x2(amax_2x, amax_2x, pair); } scale_dec_b = static_cast( max(fabsf(static_cast(amax_2x.x)), fabsf(static_cast(amax_2x.y)))); scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; scale_dec_b *= scale_enc; // Convert to mx scale or nv scale using ScaleType = std::conditional_t< use_mx_scale, cutlass::float_ue8m0_t, cutlass::float_e4m3_t>; auto s = ScaleType(scale_dec_b); uint8_t q_scale = s.storage; float scale_enc_b = scale_enc / float(s); scales[thread_idx] = q_scale; constexpr int elem_per_byte = bits == 8 ? 
1 : 2; AlignedVector quantized; #pragma unroll for (int i = 0; i < group_size / 4; i++) { Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); if constexpr (bits == 8) { uint32_t quantized_val = scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized[i * 4]) = quantized_val; } else { uint16_t quantized_val = scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized[i * 2]) = quantized_val; } } store_vector(out, thread_idx, quantized); } template __global__ void fp_quantize_columnwise( T* w, uint8_t* out, uint8_t* scales, size_t size, int M, int K, float* global_scale = nullptr) { // Input: [M, K] with strides [1, M] (M-major) // Quantized output: [M, K/elem_per_byte] row-major (K-major) // Scales: [M, K/group_size] row-major (K-major) // Quantize along K (last dimension, groups of group_size elements) const bool use_global_scale = global_scale != nullptr; const float scale_enc = use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; using Tx2 = Vector2_t; using Tx4 = Vector4_t; uint32_t rbits = 0; auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); constexpr int BLOCK_X = 16; constexpr int BLOCK_Y = 32; constexpr int elem_per_byte = (bits == 8) ? 
1 : 2; constexpr int bytes_per_group = group_size / elem_per_byte; constexpr int rows_per_block = BLOCK_X; constexpr int cols_per_block = BLOCK_Y * group_size; constexpr int local_cols = cols_per_block / elem_per_byte; constexpr int bytes_per_block = rows_per_block * local_cols; constexpr int SMEM_PAD = 4; constexpr int padded_local_cols = local_cols + SMEM_PAD; auto tidx = idx_in_block.x; auto tidy = idx_in_block.y; int num_col_blocks = (K + cols_per_block - 1) / cols_per_block; auto bidx = block_idx.x % num_col_blocks; auto bidy = block_idx.x / num_col_blocks; T thread_data[group_size]; __shared__ uint8_t quantized_smem[rows_per_block * padded_local_cols]; __shared__ uint8_t scales_smem[BLOCK_X][BLOCK_Y + SMEM_PAD]; int row_base = bidy * rows_per_block + tidx; int col_base = bidx * cols_per_block + tidy * group_size; bool valid = (row_base < M) && (col_base + group_size <= K); if (valid) { #pragma unroll for (int i = 0; i < group_size; i++) { auto index = row_base + (col_base + i) * M; thread_data[i] = w[index]; } // Compute scale Tx2 amax_2x = Tx2{0.0f, 0.0f}; #pragma unroll for (int r = 0; r < group_size; r += 2) { auto pair = Tx2{thread_data[r], thread_data[r + 1]}; absmax_x2(amax_2x, amax_2x, pair); } float scale_dec_b = max(fabsf(static_cast(amax_2x.x)), fabsf(static_cast(amax_2x.y))); scale_dec_b /= bits == 4 ? 
F4E2M1_MAX : F8E4M3_MAX; scale_dec_b *= scale_enc; // Convert to mx scale or nv scale using ScaleType = std::conditional_t< use_mx_scale, cutlass::float_ue8m0_t, cutlass::float_e4m3_t>; auto s = ScaleType(scale_dec_b); float scale_enc_b = scale_enc / float(s); scales_smem[tidx][tidy] = s.storage; int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group; #pragma unroll for (int j = 0; j < group_size / 4; j++) { Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); if constexpr (bits == 8) { uint32_t quantized_val = scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized_smem[shared_idx + j * 4]) = quantized_val; } else { uint16_t quantized_val = scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized_smem[shared_idx + j * 2]) = quantized_val; } } } __syncthreads(); int output_cols = K / elem_per_byte; int num_groups_per_row = K / group_size; int linear_tid = tidx + tidy * BLOCK_X; // Write back quantized values #pragma unroll for (int i = linear_tid; i < bytes_per_block; i += BLOCK_X * BLOCK_Y) { int local_row = i / local_cols; int local_col = i % local_cols; int global_row = bidy * rows_per_block + local_row; int global_col = bidx * local_cols + local_col; if (global_row < M && global_col < output_cols) { int physical_idx = local_row * padded_local_cols + local_col; out[global_row * output_cols + global_col] = quantized_smem[physical_idx]; } } // Write back scales constexpr int num_scales = BLOCK_X * BLOCK_Y; #pragma unroll for (int i = linear_tid; i < num_scales; i += BLOCK_X * BLOCK_Y) { int local_row = i / BLOCK_Y; int local_col = i % BLOCK_Y; int global_row = bidy * BLOCK_X + local_row; int global_col = bidx * BLOCK_Y + local_col; if (global_row < M && global_col < num_groups_per_row) { scales[global_row * num_groups_per_row + global_col] = scales_smem[local_row][local_col]; } } } template __global__ void fp_dequantize( const uint8_t* w, const uint8_t* scales, T* out, size_t size, float* global_scale = 
nullptr) { auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); auto tidx = block_idx.x * block_size.x + idx_in_block.x; auto tidy = block_idx.y * block_size.y + idx_in_block.y; auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; constexpr int pack_factor = bits == 8 ? 1 : 2; const bool use_global_scale = global_scale != nullptr; const float inv_scale_enc = use_mx_scale ? 1.0f : (use_global_scale ? (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f); size_t offset = tidx + grid_dim_x * size_t(tidy); size_t oindex = offset * pack_factor; if (oindex >= size) { return; } size_t gindex = oindex / group_size; using ScaleType = std::conditional_t< use_mx_scale, cutlass::float_ue8m0_t, cutlass::float_e4m3_t>; auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc; out += oindex; uint32_t val = w[offset]; #pragma clang loop unroll(full) for (int i = 0; i < pack_factor; i++) { uint8_t d; if (bits == 4) { d = (val >> (bits * i)) & 0x0f; } else if (bits == 8) { d = val; } out[i] = static_cast(scale * Dequantize{}(d)); } } inline std::tuple get_columnwise_quantize_launch_args(size_t size, int group_size, int M, int K) { constexpr int BLOCK_X = 16; constexpr int BLOCK_Y = 32; int rows_per_block = BLOCK_X; int cols_per_block = BLOCK_Y * group_size; dim3 grid; grid.x = cuda::ceil_div(K, cols_per_block) * cuda::ceil_div(M, rows_per_block); grid.y = 1; grid.z = 1; dim3 block(BLOCK_X, BLOCK_Y); return std::make_tuple(grid, block); } } // namespace cu void fp_quantize_dequantize( const array& w, array& what, int group_size, int bits, const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { enc.set_input_array(w); if (global_scale.has_value()) { enc.set_input_array(global_scale.value()); } enc.set_output_array(what); dispatch_float_types(w.dtype(), "fp_quantize_dequantize", [&](auto type_tag) { using 
T = cuda_type_t; if constexpr (!std::is_same_v) { auto kernel = cu::fp_quantize_dequantize; if (bits == 8) { kernel = cu::fp_quantize_dequantize; } else if (group_size == 16) { kernel = cu::fp_quantize_dequantize; } bool large = w.size() > UINT_MAX; auto [num_blocks, block_dims] = get_launch_args(w.size(), w.shape(), w.strides(), large, group_size); enc.add_kernel_node( kernel, num_blocks, block_dims, gpu_ptr(w), gpu_ptr(what), w.size(), global_scale.has_value() ? gpu_ptr(global_scale.value()) : nullptr); } }); } void fp_quantize( const array& w, array& wq, array& scales, int group_size, int bits, const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { enc.set_input_array(w); if (global_scale.has_value()) { enc.set_input_array(global_scale.value()); } enc.set_output_array(wq); enc.set_output_array(scales); if (w.strides().back() != 1) { dispatch_float_types(w.dtype(), "fp_quantize_columnwise", [&](auto type_tag) { using T = cuda_type_t; if constexpr (!std::is_same_v) { auto M = w.shape(-2); auto K = w.shape(-1); auto kernel = cu::fp_quantize_columnwise; if (bits == 8) { kernel = cu::fp_quantize_columnwise; } else if (group_size == 16) { kernel = cu::fp_quantize_columnwise; } auto [num_blocks, block_dims] = cu::get_columnwise_quantize_launch_args(w.size(), group_size, M, K); enc.add_kernel_node( kernel, num_blocks, block_dims, gpu_ptr(w), gpu_ptr(wq), gpu_ptr(scales), w.size(), M, K, global_scale.has_value() ? 
gpu_ptr(global_scale.value()) : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not quantize input with type float64."); } }); } else { dispatch_float_types(w.dtype(), "fp_quantize_rowwise", [&](auto type_tag) { using T = cuda_type_t; if constexpr (!std::is_same_v) { auto kernel = cu::fp_quantize_rowwise; if (bits == 8) { kernel = cu::fp_quantize_rowwise; } else if (group_size == 16) { kernel = cu::fp_quantize_rowwise; } bool large = w.size() > UINT_MAX; auto [num_blocks, block_dims] = get_launch_args( w.size(), w.shape(), w.strides(), large, group_size); enc.add_kernel_node( kernel, num_blocks, block_dims, gpu_ptr(w), gpu_ptr(wq), gpu_ptr(scales), w.size(), global_scale.has_value() ? gpu_ptr(global_scale.value()) : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not quantize input with type float64."); } }); } } void fp_dequantize( const array& wq, const array& scales, array& w, int group_size, int bits, const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { constexpr int uint8_per_uint32 = 4; int packs_per_int = 8 / bits; size_t size = w.size() / packs_per_int; bool large = size > UINT_MAX; auto grid_shape = w.shape(); grid_shape.back() *= uint8_per_uint32; enc.set_input_array(wq); enc.set_input_array(scales); if (global_scale.has_value()) { enc.set_input_array(global_scale.value()); } enc.set_output_array(w); dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) { using T = cuda_type_t; if constexpr (!std::is_same_v) { auto kernel = cu::fp_dequantize; if (bits == 8) { kernel = cu::fp_dequantize; } else if (group_size == 16) { kernel = cu::fp_dequantize; } auto [num_blocks, block_dims] = get_launch_args(size, grid_shape, w.strides(), large); enc.add_kernel_node( kernel, num_blocks, block_dims, gpu_ptr(wq), gpu_ptr(scales), gpu_ptr(w), w.size(), global_scale.has_value() ? 
gpu_ptr(global_scale.value()) : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not dequantize to output with type float64."); } }); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/quantized/mxfp8_quantize.cuh ================================================ #pragma once #include "mlx/backend/cuda/vector_types.cuh" #include namespace mlx::core::cu { // Place holder for future fast path implementation template __device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4( const Vector4_t& input, const float scale, uint32_t rbits) { cutlass::NumericArrayConverter fp32_t; auto scaled = fp32_t(*reinterpret_cast*>(&input)) * scale; cutlass::NumericArrayConverter fp8_fp32; auto quant = fp8_fp32(scaled); return *reinterpret_cast(&quant); } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/quantized/no_qqmm_impl.cpp ================================================ // Copyright © 2026 Apple Inc. 
#include "mlx/backend/cuda/quantized/qqmm_impl.h"

namespace mlx::core {

// Stub compiled when the QQMM implementation is unavailable: always throws.
void qqmm_impl(
    cu::CommandEncoder&,
    int,
    int,
    int,
    bool,
    int64_t,
    bool,
    int64_t,
    array&,
    const array&,
    const array&,
    const array&,
    const array&,
    QuantizationMode,
    const GemmScalars&) {
  throw std::runtime_error(
      "[QQMatmul::eval_gpu] QQMM is only supported with CUDA 12.8 or higher.");
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/nvfp4_quantize.cuh
================================================
#pragma once

#include "mlx/backend/cuda/vector_types.cuh"

#include

namespace mlx::core::cu {

using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using f32x4 = Vector4_t;

template
__device__ __forceinline__ uint16_t
scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t& input, const float scale) {
  // Fallback implementation for architectures that do not support cvt
  // instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
  cutlass::NumericArrayConverter fp32_t;
  auto scaled =
      fp32_t(*reinterpret_cast*>(&input)) * scale;
  cutlass::NumericArrayConverter fp4_fp32;
  auto quant = fp4_fp32(scaled);
  return *reinterpret_cast(&quant);
}

// Single source of truth for whether the PTX fast path exists. The fast-path
// definitions and the dispatcher at the bottom of this file previously used
// different conditions (__CUDA_ARCH_SPECIFIC__ vs
// __CUDA_ARCH_FAMILY_SPECIFIC__), so a family-specific-only target called an
// undeclared function. NOTE(review): decide whether family-specific targets
// should enable the fast path too.
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
    defined(__CUDA_ARCH_SPECIFIC__)
#define MLX_NVFP4_FAST_CVT 1
#else
#define MLX_NVFP4_FAST_CVT 0
#endif

#if MLX_NVFP4_FAST_CVT

// Scales four bf16 values by `scale` in fp32 and converts them to four packed
// E2M1 codes with round-to-nearest and saturation to finite.
__device__ __forceinline__ uint16_t
scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_bf16; \n\t" // first bf16
      ".reg.b16 x1_bf16; \n\t" // second bf16
      ".reg.b16 x2_bf16; \n\t" // third bf16
      ".reg.b16 x3_bf16; \n\t" // fourth bf16
      ".reg.b32 x0; \n\t" // to hold scaled first
      ".reg.b32 x1; \n\t" // to hold scaled second
      ".reg.b32 x2; \n\t" // to hold scaled third
      ".reg.b32 x3; \n\t" // to hold scaled fourth
      ".reg.b64 x01; \n\t" // to hold vector mul
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
      ".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
      "cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
      "cvt.f32.bf16 x1, x1_bf16; \n\t"
      "cvt.f32.bf16 x2, x2_bf16; \n\t"
      "cvt.f32.bf16 x3, x3_bf16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t" // scale first pair
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t" // scale second pair
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first
                                                     // pair
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2 second
                                                     // pair
      "mov.b16 %0, {q0, q1}; \n\t" // pack to output
      "}"
      : "=h"(out_fp4x4)
      : "l"(reinterpret_cast(input_bf16x4)),
        "l"(reinterpret_cast(
            scale))); // here cast is needed because an asm operand must have
                      // scalar type
  return out_fp4x4;
}

// Stochastic-rounding variant of the above; `rbits` supplies the random bits.
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
    const bf16x4 input_bf16x4,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_bf16; \n\t"
      ".reg.b16 x1_bf16; \n\t"
      ".reg.b16 x2_bf16; \n\t"
      ".reg.b16 x3_bf16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
      "cvt.f32.bf16 x0, x0_bf16; \n\t"
      "cvt.f32.bf16 x1, x1_bf16; \n\t"
      "cvt.f32.bf16 x2, x2_bf16; \n\t"
      "cvt.f32.bf16 x3, x3_bf16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t"
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
      // Copy the result to the output operand (previously never written).
      "mov.b16 %0, q0; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "l"(reinterpret_cast(input_bf16x4)),
        "l"(reinterpret_cast(scale)),
        "r"(rbits));
  return out_fp4x4;
}

// fp32 inputs, round-to-nearest.
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
    const float2 input_fp32x2_0,
    const float2 input_fp32x2_1,
    const float2 scale) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t"
      ".reg.b8 q1; \n\t"
      "mov.b64 x01, {%1, %2}; \n\t"
      "mul.f32x2 x01, x01, %5; \n\t"
      "mov.b64 x23, {%3, %4}; \n\t"
      "mul.f32x2 x23, x23, %5; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
      "mov.b16 %0, {q0, q1}; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "f"(input_fp32x2_0.x),
        "f"(input_fp32x2_0.y),
        "f"(input_fp32x2_1.x),
        "f"(input_fp32x2_1.y),
        "l"(reinterpret_cast(scale)));
  return out_fp4x4;
}

// fp32 inputs, stochastic rounding.
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
    const float2 input_fp32x2_0,
    const float2 input_fp32x2_1,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      "mov.b64 x01, {%1, %2}; \n\t"
      "mul.f32x2 x01, x01, %5; \n\t"
      "mov.b64 x23, {%3, %4}; \n\t"
      "mul.f32x2 x23, x23, %5; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
      // Copy the result to the output operand (previously never written).
      "mov.b16 %0, q0; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "f"(input_fp32x2_0.x),
        "f"(input_fp32x2_0.y),
        "f"(input_fp32x2_1.x),
        "f"(input_fp32x2_1.y),
        "l"(reinterpret_cast(scale)),
        "r"(rbits));
  return out_fp4x4;
}

// fp16 inputs, round-to-nearest.
__device__ __forceinline__ uint16_t
scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_fp16; \n\t"
      ".reg.b16 x1_fp16; \n\t"
      ".reg.b16 x2_fp16; \n\t"
      ".reg.b16 x3_fp16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t"
      ".reg.b8 q1; \n\t"
      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
      "cvt.f32.f16 x0, x0_fp16; \n\t"
      "cvt.f32.f16 x1, x1_fp16; \n\t"
      "cvt.f32.f16 x2, x2_fp16; \n\t"
      "cvt.f32.f16 x3, x3_fp16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t"
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
      "mov.b16 %0, {q0, q1}; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "l"(reinterpret_cast(input_fp16x4)),
        "l"(reinterpret_cast(scale)));
  return out_fp4x4;
}

// fp16 inputs, stochastic rounding.
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
    const fp16x4 input_fp16x4,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_fp16; \n\t"
      ".reg.b16 x1_fp16; \n\t"
      ".reg.b16 x2_fp16; \n\t"
      ".reg.b16 x3_fp16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
      "cvt.f32.f16 x0, x0_fp16; \n\t"
      "cvt.f32.f16 x1, x1_fp16; \n\t"
      "cvt.f32.f16 x2, x2_fp16; \n\t"
      "cvt.f32.f16 x3, x3_fp16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t"
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
      // Copy the result to the output operand (previously never written).
      "mov.b16 %0, q0; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "l"(reinterpret_cast(input_fp16x4)),
        "l"(reinterpret_cast(scale)),
        "r"(rbits));
  return out_fp4x4;
}

template
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
    const bf16x4 input,
    const float scale,
    uint32_t rbits) {
  float2 scale_fp32x2 = make_float2(scale, scale);
  if constexpr (USE_SR) {
    return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
  } else {
    return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
  }
}

template
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
    const fp16x4 input,
    const float scale,
    uint32_t rbits) {
  float2 scale_fp32x2 = make_float2(scale, scale);
  if constexpr (USE_SR) {
    return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
  } else {
    return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
  }
}

template
__device__ __forceinline__ uint16_t
scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
  float2 scale_fp32x2 = make_float2(scale, scale);
  float2 input_fp32x2_0 = make_float2(input.x, input.y);
  float2 input_fp32x2_1 = make_float2(input.z, input.w);

  if constexpr (USE_SR) {
    return scale_cvt_fp32x4_to_fp4x4_rs(
        input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
  } else {
    return scale_cvt_fp32x4_to_fp4x4_rn(
        input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
  }
}

// Selects the bf16 / fp16 / fp32 fast-path helper by element type.
template
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
    const Vector4_t input,
    const float scale,
    uint32_t rbits) {
  if constexpr (std::is_same::value) {
    return scale_cvt_bf16x4_to_fp4x4(input, scale, rbits);
  } else if constexpr (std::is_same::value) {
    return scale_cvt_fp16x4_to_fp4x4(input, scale, rbits);
  } else {
    return scale_cvt_f32x4_to_fp4x4(input, scale, rbits);
  }
}
#endif // MLX_NVFP4_FAST_CVT

// Scales four values and converts them to four packed E2M1 codes, using the
// PTX fast path when MLX_NVFP4_FAST_CVT is set and the cutlass fallback
// otherwise (stochastic rounding requires the fast path).
template
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
    const Vector4_t& input,
    const float scale,
    uint32_t rbits) {
#if MLX_NVFP4_FAST_CVT
  return scale_cvt_Tx4_to_fp4x4_fast(input, scale, rbits);
#else
  static_assert(
      !USE_SR,
      "Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
  return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
#endif
}
} // namespace mlx::core::cu
================================================
FILE: mlx/backend/cuda/quantized/qmm/CMakeLists.txt
================================================
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/fp_qmv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m16.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m32.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m64.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n16_m1.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n32_m1.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n64_m2.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n128_m2.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n256_m2.cu)
================================================
FILE: mlx/backend/cuda/quantized/qmm/fp_qmv.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/quantized.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/qmm/qmm.h"
#include "mlx/backend/cuda/quantized/quantized_utils.h"
#include "mlx/dtype_utils.h"

#include
#include
#include
#include

namespace mlx::core {

constexpr int rows_per_block = 8;

namespace cu {

namespace cg = cooperative_groups;

// Advances x, w, scales and y to the batch element selected by blockIdx.z.
template
__device__ void adjust_matrix_offsets(
    const T*& x,
    const uint32_t*& w,
    const uint8_t*& scales,
    T*& y,
    int output_stride,
    const int& x_batch_ndims,
    const Shape x_shape,
    const Strides x_strides,
    const int& w_batch_ndims,
    const Shape w_shape,
    const Strides w_strides,
    const Strides s_strides) {
  uint32_t idx = cg::this_grid().block_index().z;
  if (x_batch_ndims == 1) {
    x += idx * x_strides[0];
  } else {
    x += elem_to_loc(idx, x_shape.data(), x_strides.data(), x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += idx * w_strides[0];
    scales += idx * s_strides[0];
  } else {
    auto [w_idx, s_idx] = elem_to_loc(
        idx, w_shape.data(), w_strides.data(), s_strides.data(), w_batch_ndims);
    w += w_idx;
    scales += s_idx;
  }
  y += idx * output_stride;
}

// Matrix-vector product with an FP8/FP4-quantized matrix: each warp computes
// one output row, reducing per-lane partial sums with cg::reduce.
template <
    typename T,
    int rows_per_block,
    int n_per_thread,
    int bits,
    int group_size,
    bool use_mx_scale>
__device__ void fp_qmv_impl(
    const uint32_t* mat,
    const uint8_t* scales_,
    const T* vec,
    T* out,
    int rows,
    int cols) {
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition(block);

  constexpr int vals_per_item = bits == 8 ? 4 : 8;
  constexpr int nv_per_thread = vals_per_item * n_per_thread;
  auto g_idx = block.group_index();
  auto t_idx = block.thread_index();
  int row = g_idx.y * rows_per_block + t_idx.y;

  vec += g_idx.x * cols;
  out += g_idx.x * rows;

  using ScaleType = std::conditional_t<
      use_mx_scale,
      cutlass::float_ue8m0_t,
      cutlass::float_e4m3_t>;
  auto scales = (ScaleType*)(scales_);
  auto packed_cols = cols / vals_per_item;

  if (row < rows) {
    constexpr int scales_per_step = std::max(nv_per_thread / group_size, 1);
    constexpr int scale_step = (WARP_SIZE * nv_per_thread) / group_size;
    constexpr int n_per_step = n_per_thread / scales_per_step;
    // Offset scales to correct row
    scales += row * (cols / group_size) +
        (warp.thread_rank() * nv_per_thread) / group_size;
    float sum = 0.0f;
    for (int col = n_per_thread * warp.thread_rank(); col < packed_cols;
         col += (WARP_SIZE * n_per_thread)) {
      auto local_vec =
          unsafe_load_vector(vec + vals_per_item * col, 0);
      auto local_mat =
          unsafe_load_vector(mat + row * packed_cols + col, 0);
#pragma unroll
      for (int i = 0; i < scales_per_step; ++i) {
        float2 local_sum = {0.0f, 0.0f};
#pragma unroll
        for (int j = 0; j < n_per_step; ++j) {
          int k = n_per_step * i + j;
          if constexpr (bits == 8) {
            cutlass::NumericArrayConverter converter;
            auto v = converter(
                *reinterpret_cast*>(
                    &local_mat[k]));
            local_sum.x +=
                v[0] * static_cast(local_vec[vals_per_item * k]);
            local_sum.x +=
                v[1] * static_cast(local_vec[vals_per_item * k + 1]);
            local_sum.y +=
                v[2] * static_cast(local_vec[vals_per_item * k + 2]);
            local_sum.y +=
                v[3] * static_cast(local_vec[vals_per_item * k + 3]);
          } else {
            cutlass::NumericArrayConverter converter;
            auto v = converter(
                *reinterpret_cast*>(
                    &local_mat[k]));
            local_sum.x +=
                v[0] * static_cast(local_vec[vals_per_item * k]);
            local_sum.y +=
                v[1] * static_cast(local_vec[vals_per_item * k + 1]);
            local_sum.x +=
                v[2] * static_cast(local_vec[vals_per_item * k + 2]);
            local_sum.y +=
                v[3] * static_cast(local_vec[vals_per_item * k + 3]);
            local_sum.x +=
                v[4] * static_cast(local_vec[vals_per_item * k + 4]);
            local_sum.y +=
                v[5] * static_cast(local_vec[vals_per_item * k + 5]);
            local_sum.x +=
                v[6] * static_cast(local_vec[vals_per_item * k + 6]);
            local_sum.y +=
                v[7] * static_cast(local_vec[vals_per_item * k + 7]);
          }
        }
        sum += (local_sum.x + local_sum.y) * float(scales[i]);
      }
      scales += scale_step;
    }

    sum = cg::reduce(warp, sum, cg::plus{});
    if (warp.thread_rank() == 0) {
      out[row] = static_cast(sum);
    }
  }
}

template <
    typename T,
    int rows_per_block,
    int n_per_thread,
    int bits,
    int group_size,
    bool use_mx_scale>
__global__ void fp_qmv_single(
    const uint32_t* mat,
    const uint8_t* scales,
    const T* vec,
    T* out,
    int rows,
    int cols) {
  fp_qmv_impl(
      mat, scales, vec, out, rows, cols);
}

// Batched variant: offsets all operands by blockIdx.z before the GEMV.
template <
    typename T,
    int rows_per_block,
    int n_per_thread,
    int bits,
    int group_size,
    bool use_mx_scale>
__global__ void fp_qmv_batched(
    const uint32_t* mat,
    const uint8_t* scales,
    const T* vec,
    T* out,
    int rows,
    int cols,
    int vec_batch_ndims,
    const __grid_constant__ Shape vec_shape,
    const __grid_constant__ Strides vec_strides,
    int mat_batch_ndims,
    const __grid_constant__ Shape mat_shape,
    const __grid_constant__ Strides mat_strides,
    const __grid_constant__ Strides scales_strides) {
  adjust_matrix_offsets(
      vec,
      mat,
      scales,
      out,
      rows * vec_shape[vec_batch_ndims],
      vec_batch_ndims,
      vec_shape,
      vec_strides,
      mat_batch_ndims,
      mat_shape,
      mat_strides,
      scales_strides);
  fp_qmv_impl(
      mat, scales, vec, out, rows, cols);
}

} // namespace cu

// Calls f with an integral_constant for n in {1, 2, 4}; other values are
// ignored.
template
void dispatch_1_2_4(int n, F&& f) {
  switch (n) {
    case 1:
      f(std::integral_constant{});
      break;
    case 2:
      f(std::integral_constant{});
      break;
    case 4:
      f(std::integral_constant{});
      break;
  }
}

void fp_qmv(
    const array& x,
    const array& w,
    const array& scales_,
    array& out,
    int bits,
    int group_size,
    cu::CommandEncoder& encoder,
    Stream s) {
  uint32_t M = x.shape(-2);
  uint32_t N = out.shape(-1);
  uint32_t K = x.shape(-1);
  uint32_t B = out.size() / (M * N);

  // Make sure the last two dims of x and w, s, b are contiguous. This should
  // be relaxed for x.
  array vec = ensure_row_contiguous_matrix(x, encoder, s);
  array mat = ensure_row_contiguous_matrix(w, encoder, s);
  array scales = ensure_row_contiguous_matrix(scales_, encoder, s);

  encoder.set_input_array(mat);
  encoder.set_input_array(scales);
  encoder.set_input_array(vec);
  encoder.set_output_array(out);
  dispatch_float_types(out.dtype(), "qmv", [&](auto type_tag) {
    using T = cuda_type_t;
    if constexpr (!std::is_same_v) {
      dim3 block_dims{WARP_SIZE, rows_per_block};
      uint32_t blocks_y = (N + rows_per_block - 1) / rows_per_block;
      const uint32_t* mat_ptr = gpu_ptr(mat);
      const T* vec_ptr = gpu_ptr(vec);
      // Widest per-thread load the operand alignment allows.
      // NOTE(review): for bits == 4 the `is_aligned<8>` clause is subsumed by
      // the following `is_aligned<4>` clause, so the stricter alignment is
      // never required; confirm whether
      // `bits == 4 ? is_aligned<8> : is_aligned<4>` was intended.
      int n = 1;
      if (K % 32 == 0 && cu::is_aligned<4>(mat_ptr) &&
          ((bits == 4 && cu::is_aligned<8>(vec_ptr)) ||
           cu::is_aligned<4>(vec_ptr))) {
        n = 4;
      } else if (
          cu::is_aligned<2>(mat_ptr) &&
          ((bits == 4 && cu::is_aligned<4>(vec_ptr)) ||
           cu::is_aligned<2>(vec_ptr))) {
        n = 2;
      }
      dispatch_1_2_4(n, [&](auto n) {
        if (B == 1) {
          auto kernel = cu::fp_qmv_single;
          if (bits == 8) {
            kernel = cu::fp_qmv_single;
          } else if (group_size == 16) {
            kernel = cu::fp_qmv_single;
          }
          encoder.add_kernel_node(
              kernel,
              {uint32_t(x.size() / K), blocks_y},
              block_dims,
              mat_ptr,
              gpu_ptr(scales),
              vec_ptr,
              gpu_ptr(out),
              N,
              K);
        } else {
          auto kernel = cu::fp_qmv_batched;
          if (bits == 8) {
            kernel = cu::fp_qmv_batched;
          } else if (group_size == 16) {
            kernel = cu::fp_qmv_batched;
          }
          encoder.add_kernel_node(
              kernel,
              {M, blocks_y, B},
              block_dims,
              mat_ptr,
              gpu_ptr(scales),
              vec_ptr,
              gpu_ptr(out),
              N,
              K,
              vec.ndim() - 2,
              const_param(vec.shape()),
              const_param(vec.strides()),
              mat.ndim() - 2,
              const_param(mat.shape()),
              const_param(mat.strides()),
              const_param(scales.strides()));
        }
      });
    }
  });
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm.cu
================================================
// Copyright © 2026 Apple Inc.
#include "mlx/backend/cuda/quantized/qmm/qmm.h"

#include

namespace mlx::core {

#if defined(MLX_CUDA_SM90A_ENABLED)
// Defined in qmm_impl_sm90_xxx.cu files.
template
void qmm_impl_sm90(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    array& out,
    int bits,
    int group_size,
    cu::CommandEncoder& encoder,
    Stream s);
#endif // defined(MLX_CUDA_SM90A_ENABLED)

// Returns whether the Hopper (sm90) CUTLASS kernel behind qmm_sm90 can handle
// this product: major compute capability 9, K a multiple of 64, affine
// quantization with biases, row-contiguous x/w/scales/biases, transposed
// weights, and a bit width / group size that the sm90 dispatchers in
// qmm_impl_sm90.cuh instantiate (2, 4 or 8 bits; group size 64 or 128).
bool supports_qmm_sm90(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    const array& out,
    bool transpose,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::Device& device) {
  if (device.compute_capability_major() != 9) {
    return false;
  }
  int k = x.shape(-1);
  if (k % 64 != 0) {
    return false;
  }
  if (!biases) {
    return false;
  }
  if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
      !scales.flags().row_contiguous || !biases->flags().row_contiguous) {
    return false;
  }
  if (!transpose) {
    return false;
  }
  // Only the widths that dispatch_quant_types in qmm_impl_sm90.cuh handles;
  // an "even bits" test would also admit 6, which that dispatcher rejects.
  if (bits != 2 && bits != 4 && bits != 8) {
    return false;
  }
  // Only the group sizes that dispatch_groups in qmm_impl_sm90.cuh handles.
  if (group_size != 64 && group_size != 128) {
    return false;
  }
  // NOTE(review): this rejects every K larger than the group size, so with the
  // checks above the sm90 path is only taken for K <= 128. Confirm whether
  // `k % group_size != 0` was intended instead.
  if (group_size < k) {
    return false;
  }
  if (mode != QuantizationMode::Affine) {
    return false;
  }
  return true;
}

// Runs the sm90 quantized matmul using the qmm_impl_sm90 instantiation whose
// tile/cluster shape fits the number of output rows M. The kernel treats the
// problem as (N, M, K, L), so the second tile extent is the one that grows
// with M; e.g. <128, 16, 1> selects TileShapeMN = Shape<_128, _16> and
// ClusterShape = Shape<_1, _1, _1> (qmm_impl_sm90_m128_n16_m1.cu).
// Throws std::runtime_error when built without MLX_CUDA_SM90A_ENABLED.
void qmm_sm90(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    array& out,
    int bits,
    int group_size,
    cu::CommandEncoder& encoder,
    Stream s) {
#if defined(MLX_CUDA_SM90A_ENABLED)
  auto dispatch = [&]() {
    using cute::Int;
    using TileShapeMN = cute::Shape, Int>;
    using ClusterShape = cute::Shape, Int<1>, Int<1>>;
    qmm_impl_sm90(
        x, w, scales, biases, out, bits, group_size, encoder, s);
  };
  int m = out.shape(-2);
  if (m <= 16) {
    dispatch.template operator()<128, 16, 1>();
  } else if (m <= 32) {
    dispatch.template operator()<128, 32, 1>();
  } else if (m <= 64) {
    dispatch.template operator()<128, 64, 2>();
  } else if (m <= 128) {
    dispatch.template operator()<128, 128, 2>();
  } else {
    dispatch.template operator()<128, 256, 2>();
  }
#else
  throw std::runtime_error(
      "[quantized_matmul] Hopper-only kernel is not available.");
#endif // defined(MLX_CUDA_SM90A_ENABLED)
}

// Defined in qmm_impl_sm80_xxx.cu files.
template
void qmm_impl_sm80(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder);

// Returns whether the CUTLASS/CuTe kernel behind qmm_sm80 can handle this
// product: major compute capability >= 8, N a multiple of 128 (the kernel's
// N tile), K a multiple of max(64, group_size), row-contiguous x/w/scales
// (and biases when present), float16/bfloat16 activations, transposed
// weights, and 4- or 8-bit quantization.
bool supports_qmm_sm80(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    const array& out,
    bool transpose,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::Device& device) {
  if (device.compute_capability_major() < 8) {
    return false;
  }
  int n = out.shape(-1);
  int k = x.shape(-1);
  if ((n % 128 != 0) || (k % std::max(64, group_size) != 0)) {
    return false;
  }
  if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
      !scales.flags().row_contiguous) {
    return false;
  }
  if (biases && !biases->flags().row_contiguous) {
    return false;
  }
  if (x.dtype() != float16 && x.dtype() != bfloat16) {
    return false;
  }
  if (!transpose) {
    return false;
  }
  if (bits != 4 && bits != 8) {
    return false;
  }
  return true;
}

// Runs the sm80 quantized matmul with an M tile of 16, 32 or 64 rows, picked
// from the number of output rows.
void qmm_sm80(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder) {
  auto dispatch = [&]() {
    qmm_impl_sm80(
        x, w, scales, biases, out, bits, group_size, mode, encoder);
  };
  int m = out.shape(-2);
  if (m <= 16) {
    dispatch.template operator()<16>();
  } else if (m <= 32) {
    dispatch.template operator()<32>();
  } else {
    dispatch.template operator()<64>();
  }
}

// Returns whether fp_qmv should handle this product: major compute capability
// above 9, at most 8 input vectors, transposed weights, and a non-affine
// quantization mode.
bool supports_fp_qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    const array& out,
    bool transpose,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::Device& device) {
  // The fp_qmv kernel uses fewer registers and is faster for sm120. For sm80/90
  // the qmv kernel is faster. We didn't test sm89/100.
  if (device.compute_capability_major() <= 9) {
    return false;
  }
  // With 2-D weights every row of x is a vector; with batched weights only
  // the second-to-last axis of x counts.
  bool non_batched = w.ndim() == 2;
  int k = x.shape(-1);
  int vec_batch = non_batched ? x.size() / k : x.shape(-2);
  if (vec_batch > 8) {
    return false;
  }
  if (!transpose) {
    return false;
  }
  if (mode == QuantizationMode::Affine) {
    return false;
  }
  return true;
}

// Returns whether the generic qmv kernel can handle this product: K a
// multiple of 8, row-contiguous x/w/scales (and biases when present), and
// transposed weights.
bool supports_qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    const array& out,
    bool transpose,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::Device& device) {
  int k = x.shape(-1);
  if (k % 8 != 0) {
    return false;
  }
  if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
      !scales.flags().row_contiguous) {
    return false;
  }
  if (biases && !biases->flags().row_contiguous) {
    return false;
  }
  if (!transpose) {
    return false;
  }
  return true;
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm.h
================================================
// Copyright © 2026 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"
#include "mlx/primitives.h"

#include

// Entry points of the CUDA quantized matmul/matvec kernels. Each kernel entry
// point is paired with a supports_* predicate that reports whether it accepts
// the given arguments.
namespace mlx::core {

bool supports_qmm_sm90(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    const array& out,
    bool transpose,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::Device& device);

void qmm_sm90(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    array& out,
    int bits,
    int group_size,
    cu::CommandEncoder& encoder,
    Stream s);

bool supports_qmm_sm80(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    const array& out,
    bool transpose,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::Device& device);

void qmm_sm80(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder);

bool supports_fp_qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    const array& out,
    bool transpose,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::Device& device);

void fp_qmv(
    const array& x,
    const array& w,
    const array& scales,
    array& out,
    int bits,
    int group_size,
    cu::CommandEncoder& encoder,
    Stream s);

bool supports_qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    const array& out,
    bool transpose,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::Device& device);

void qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder);

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh
================================================
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/quantized/qmm/qmm.h"
#include "mlx/dtype_utils.h"

#include
#include

// clang-format off

// We can't put kernel code in mlx::core due to name conflicts of "Shape".
namespace cutlass_gemm {

using namespace cute;

// Whether dequantization adds a zero point (the biases). Defined as the
// negation of cutlass::has_negative_zero_v for the quant type, so presumably
// true for the integer (affine) formats and false for the floating-point
// formats -- TODO confirm against the template argument.
template
constexpr bool has_zero_point_v = !cutlass::has_negative_zero_v;

// Shared memory of qmm_sm80_kernel. The mainloop A/B staging buffers and the
// epilogue C tile overlap in a union, so C may only be written once no read
// and no in-flight cp.async targets A/B anymore.
template
union SharedStorage {
  struct {
    ArrayEngine> A;
    ArrayEngine> B;
  } mainloop;
  struct {
    ArrayEngine> C;
  } epilogue;
};

// Dequantizes one register fragment: out = w * s (+ z when the quant type has
// a zero point). s and z must each hold one element and w must be contiguous;
// the numeric conversion is done by cutlass::NumericArrayConverter.
template
__device__ __forceinline__ void
dequant(const Q& w, const S& s, const Z& z, T out) {
  // Scale must be one element.
  CUTE_STATIC_ASSERT_V(cosize(s.layout()) == Int<1>{});
  CUTE_STATIC_ASSERT_V(cosize(z.layout()) == Int<1>{});
  // Quant must be contiguous.
  auto layout = coalesce(w.layout());
  CUTE_STATIC_ASSERT_V(stride(layout) == Int<1>{});
  // Use cutlass for conversions.
  constexpr int N = size(layout);
  using Element = typename T::value_type;
  using Quant = typename Q::value_type;
  auto& w_vec = *(reinterpret_cast*>(raw_pointer_cast(w.data())));
  Element scale{s[0]};
  cutlass::NumericArrayConverter converter;
  auto w_dq = converter(w_vec) * scale;
  if constexpr (has_zero_point_v) {
    Element zero_point{z[0]};
    w_dq = w_dq + zero_point;
  }
  copy(make_tensor(make_rmem_ptr(&w_dq), out.layout()), out);
}

// One CTA computes a (BLK_M,BLK_N) tile of C = A * dequant(B)^T for batch
// l_coord. A/B tiles are staged through a K_PIPE_MAX-deep cp.async pipeline in
// shared memory, B is dequantized in registers with per-group scales (and zero
// points) copied straight from global memory, and the converted accumulators
// are written back to C through shared memory. Only rows past M are
// predicated; N and K are expected to be multiples of the CTA tile
// (supports_qmm_sm80 enforces N % 128 and K % max(64, group_size)).
template
__global__ void qmm_sm80_kernel(
    ProblemShape shape_MNKL, CtaTiler cta_tiler,
    const Element* A, StrideA dA, SmemLayoutA sA_layout, TiledCopyA g2s_copy_a, S2RAtomA s2r_atom_a,
    const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB g2s_copy_b, S2RAtomB s2r_atom_b,
    Element* C, StrideC dC, SmemLayoutC sC_layout, TiledCopyC s2g_copy_c, R2SAtomC r2s_atom_c,
    const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s,
    TiledMma mma) {
  CUTE_STATIC_ASSERT_V(size(g2s_copy_a) == size(mma));
  CUTE_STATIC_ASSERT_V(size(g2s_copy_b) == size(mma));
  CUTE_STATIC_ASSERT_V(size(s2g_copy_c) == size(mma));
  CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA));
  CUTE_STATIC_ASSERT_V(congruent(select<1,2,3>(shape_MNKL), dB));
  CUTE_STATIC_ASSERT_V(congruent(select<0,1,3>(shape_MNKL), dC));

  int thread_idx = int(threadIdx.x);
  auto [m_coord, n_coord, l_coord] = static_cast(blockIdx);

  // Represent the full tensors.
  Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
  Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
  Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L)

  Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L)
  Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L)

  // Get batch slice.
  Tensor mA = mA_mkl(_,_,l_coord); // (M,K)
  Tensor mB = mB_nkl(_,_,l_coord); // (N,K)
  Tensor mC = mC_mnl(_,_,l_coord); // (M,N)

  Tensor mS = mS_nkl(_,_,l_coord); // (N,(group_size,K/group_size))
  Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size))

  // Get the appropriate blocks for this thread block.
  auto cta_coord = make_coord(m_coord, n_coord, _);                   // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)

  Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
  Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)

  // Shared memory buffers.
  extern __shared__ char shared_memory[];
  using SharedStorage = SharedStorage;
  SharedStorage& smem = *reinterpret_cast(shared_memory);
  Tensor sA = make_tensor(make_smem_ptr(smem.mainloop.A.begin()), sA_layout); // (BLK_M,BLK_K)
  Tensor sB = make_tensor(make_smem_ptr(smem.mainloop.B.begin()), sB_layout); // (BLK_N,BLK_K)
  Tensor sC = make_tensor(make_smem_ptr(smem.epilogue.C.begin()), sC_layout); // (BLK_M,BLK_N)

  // Partition the copying of A/B/C tiles across the threads.
  ThrCopy g2s_thr_copy_a = g2s_copy_a.get_slice(thread_idx);
  Tensor tAgA = g2s_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k)
  Tensor tAsA = g2s_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE)

  ThrCopy g2s_thr_copy_b = g2s_copy_b.get_slice(thread_idx);
  Tensor tBgB = g2s_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k)
  Tensor tBsB = g2s_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE)

  ThrCopy s2g_thr_copy_c = s2g_copy_c.get_slice(thread_idx);
  Tensor s2g_tCsC = s2g_thr_copy_c.partition_S(sC); // (CCPY,CCPY_M,CCPY_N)
  Tensor s2g_tCgC = s2g_thr_copy_c.partition_D(gC); // (CCPY,CCPY_M,CCPY_N)

  // MMA.
  ThrMMA thr_mma = mma.get_slice(thread_idx);
  Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K)
  Tensor tCsB = thr_mma.partition_B(sB(_,_,0));          // (MMA,MMA_N,MMA_K)
  Tensor tCrB = make_fragment_like(tCsB);                // (MMA,MMA_N,MMA_K)
  Tensor tCrB_dq = make_fragment_like(tCsB);             // (MMA,MMA_N,MMA_K)
  Tensor tCgC = thr_mma.partition_C(gC);                 // (MMA,MMA_M,MMA_N)
  Tensor tCrC_accu = make_fragment_like(tCgC);           // (MMA,MMA_M,MMA_N)
  Tensor tCrC = make_fragment_like(tCgC);                // (MMA,MMA_M,MMA_N)

  Tensor tCgS = thr_mma.partition_B(gS);                 // (MMA,MMA_N,MMA_K,k)
  Tensor tCrS = make_tensor_like(tCgS(_,_,_,0));         // (MMA,MMA_N,MMA_K)
  Tensor tCgZ = thr_mma.partition_B(gZ);                 // (MMA,MMA_N,MMA_K,k)
  Tensor tCrZ = make_tensor_like(tCgZ(_,_,_,0));         // (MMA,MMA_N,MMA_K)

  // Copy Atom retiling.
  TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
  ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(thread_idx);
  Tensor s2r_tCsA = s2r_thr_copy_a.partition_S(sA); // (ACPY,MMA_M,MMA_K,PIPE)
  Tensor s2r_tCrA = s2r_thr_copy_a.retile_D(tCrA);  // (ACPY,MMA_M,MMA_K)

  TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
  ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(thread_idx);
  Tensor s2r_tCsB = s2r_thr_copy_b.partition_S(sB); // (BCPY,MMA_N,MMA_K,PIPE)
  Tensor s2r_tCrB = s2r_thr_copy_b.retile_D(tCrB);  // (BCPY,MMA_N,MMA_K)

  TiledCopy r2s_copy_c = make_tiled_copy_C(r2s_atom_c, mma);
  ThrCopy r2s_thr_copy_c = r2s_copy_c.get_slice(thread_idx);
  Tensor r2s_tCrC = r2s_thr_copy_c.retile_S(tCrC);  // (CCPY,MMA_M,MMA_N)
  Tensor r2s_tCsC = r2s_thr_copy_c.partition_D(sC); // (CCPY,MMA_M,MMA_N)

  TiledCopy g2r_copy_s = make_tiled_copy_B(g2r_atom_s, mma);
  ThrCopy g2r_thr_copy_s = g2r_copy_s.get_slice(thread_idx);
  Tensor g2r_tCgS = g2r_thr_copy_s.partition_S(gS); // (BCPY,MMA_N,MMA_K,k)
  Tensor g2r_tCrS = g2r_thr_copy_s.retile_D(tCrS);  // (BCPY,MMA_N,MMA_K)
  Tensor g2r_tCgZ = g2r_thr_copy_s.partition_S(gZ); // (BCPY,MMA_N,MMA_K,k)
  Tensor g2r_tCrZ = g2r_thr_copy_s.retile_D(tCrZ);  // (BCPY,MMA_N,MMA_K)

  // Predicates for m bound.
  auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord
  Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{});         // (CPY_M,CPY_K)
  Tensor tCpC = make_tensor(make_shape(size<1>(s2g_tCsC), size<2>(s2g_tCsC)), Stride<_1,_0>{}); // (CPY_M,CPY_N)
  Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K)
  Tensor cC = make_identity_tensor(make_shape(size<0>(sC), size<1>(sC))); // (BLK_M,BLK_N)
  Tensor tAcA = g2s_thr_copy_a.partition_D(cA); // (CPY,CPY_M,CPY_K)
  Tensor tCcC = s2g_thr_copy_c.partition_D(cC); // (CPY,CPY_M,CPY_N)
  CUTE_UNROLL
  for (int m = 0; m < size<0>(tApA); ++m) {
    tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord;
  }
  CUTE_UNROLL
  for (int m = 0; m < size<0>(tCpC); ++m) {
    tCpC(m,0) = get<0>(tCcC(0,m,0)) < m_max_coord;
  }

  auto K_PIPE_MAX = size<3>(tAsA);
  int smem_pipe_read = 0;
  int smem_pipe_write = 0;

  // Copy A/B: GMEM => SMEM.
  auto fetch_gmem = [&](int tile) {
    copy_if(g2s_copy_a, tApA, tAgA(_,_,_,tile), tAsA(_,_,_,smem_pipe_write));
    copy(g2s_copy_b, tBgB(_,_,_,tile), tBsB(_,_,_,smem_pipe_write));
    cp_async_fence();
    smem_pipe_write = (smem_pipe_write + 1) % K_PIPE_MAX;
  };
  // Copy S/Z: GMEM => RMEM.
  auto fetch_scales = [&](int tile) {
    copy(g2r_copy_s, g2r_tCgS(_,_,_,tile), g2r_tCrS);
    if constexpr (has_zero_point_v) {
      copy(g2r_copy_s, g2r_tCgZ(_,_,_,tile), g2r_tCrZ);
    }
  };
  // Copy A/B: SMEM => RMEM, then dequantize B into tCrB_dq.
  auto fetch_smem = [&](auto block) {
    copy(s2r_atom_a, s2r_tCsA(_,_,block,smem_pipe_read), s2r_tCrA(_,_,block));
    copy(s2r_atom_b, s2r_tCsB(_,_,block,smem_pipe_read), s2r_tCrB(_,_,block));
    CUTE_UNROLL
    for (int n = 0; n < size<1>(tCrB); ++n) {
      dequant(tCrB(_,n,block), tCrS(_,n,block), tCrZ(_,n,block), tCrB_dq(_,n,block));
    }
  };

  auto K_TILE_MAX = size<3>(tAgA);
  auto K_BLOCK_MAX = size<2>(tCrA);

  // Prefetch beginning tiles.
  // When K has fewer tiles than the pipeline has stages, re-load the last tile
  // instead of reading past the end of A/B along K. The surplus stages are
  // never consumed, but each still issues its own cp.async group so that the
  // cp_async_wait() counts below stay in step with the pipeline.
  int tile_pipe = 0;
  CUTE_UNROLL
  for (; tile_pipe < K_PIPE_MAX - 1; ++tile_pipe) {
    fetch_gmem(tile_pipe < K_TILE_MAX ? tile_pipe : K_TILE_MAX - 1);
  }
  // The main loop fetches `tile_pipe` before advancing it, so it must already
  // be a valid tile index here.
  tile_pipe = tile_pipe < K_TILE_MAX ? tile_pipe : K_TILE_MAX - 1;

  // Clear accumulators.
  clear(tCrC_accu);

  // Prefetch first block.
  if constexpr (K_BLOCK_MAX > 1) {
    cp_async_wait();
    __syncthreads();
    fetch_scales(0);
    fetch_smem(Int<0>{});
  }

  // Loop over CTA tiles.
  for (int tile = 0; tile < K_TILE_MAX; ++tile) {
    // Unroll MMA blocks.
    CUTE_UNROLL
    for (int block = 0; block < K_BLOCK_MAX; ++block) {
      // Before the last block of this tile: move to the next smem stage, wait
      // for its copies, and load the next tile's scales (clamped at the end).
      if (block == K_BLOCK_MAX - 1) {
        smem_pipe_read = (smem_pipe_read + 1) % K_PIPE_MAX;
        cp_async_wait();
        __syncthreads();
        fetch_scales((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
      }
      // Prefetch next block.
      fetch_smem((block + 1) % K_BLOCK_MAX);
      // Prefetch next tile (clamped to the last tile).
      if (block == 0) {
        fetch_gmem(tile_pipe);
        tile_pipe = (tile_pipe + 1 < K_TILE_MAX) ? tile_pipe + 1 : tile_pipe;
      }
      // MMA.
      gemm(mma, tCrA(_,_,block), tCrB_dq(_,_,block), tCrC_accu);
    }
  }

  // Epilogue.
  CUTE_UNROLL
  for (int i = 0; i < size(tCrC_accu); i++) {
    tCrC(i) = Element(tCrC_accu(i));
  }
  // sC aliases the mainloop A/B buffers through the SharedStorage union, and
  // the mainloop only waits for all but the most recent cp.async group(s).
  // Drain every outstanding copy on all threads before any thread writes sC,
  // otherwise a late-landing copy could overwrite part of the output tile.
  cp_async_wait<0>();
  __syncthreads();
  copy(r2s_copy_c, r2s_tCrC, r2s_tCsC);
  __syncthreads();
  copy_if(s2g_copy_c, tCpC, s2g_tCsC, s2g_tCgC);
}

// Returns the SM80 16x8x16 tensor-core MMA atom with fp32 accumulation for
// half or bfloat16 inputs (returns void for any other element type).
template
inline constexpr auto make_mma_atom() {
  if constexpr (std::is_same_v) {
    return SM80_16x8x16_F32F16F16F32_TN{};
  }
  if constexpr (std::is_same_v) {
    return SM80_16x8x16_F32BF16BF16F32_TN{};
  }
}

// Tiles the MMA atom over the CTA; the tile argument is 32x32x16 when
// TileM >= 32 and 16x32x16 otherwise.
template
inline constexpr auto make_tiled_mma() {
  constexpr auto atom = make_mma_atom();
  if constexpr (TileM >= 32) {
    return make_tiled_mma(atom, Layout>{}, Tile<_32,_32,_16>{});
  } else {
    return make_tiled_mma(atom, Layout>{}, Tile<_16,_32,_16>{});
  }
}

// Builds a TiledCopy of T values from the copy atom `Atom`: threads are laid
// out row-major with 8 threads per row, and each thread moves a 1 x V vector
// per copy (V is derived from the template arguments).
template typename Atom, typename NumThreads>
inline auto make_tiled_copy(NumThreads num_threads) {
  return make_tiled_copy(
      Copy_Atom>, T>{},
      make_layout(make_shape(Int{}, Int<8>{}), LayoutRight{}),
      make_layout(make_shape(Int<1>{}, Int>{})));
}

// Host-side setup for qmm_sm80_kernel: C(M,N,L) = A(M,K,L) * dequant(B(N,K,L))^T
// with K contiguous in A and B and N contiguous in C, using a 3-stage smem
// pipeline. Configures the kernel's shared memory attributes and hands it to
// launch_kernel(kernel, num_blocks, block_dims, smem_bytes, args).
template
void qmm_sm80(
    const Element* A, const Quant* B, const Scale* S, const Element* Z, Element* C,
    int m, int n, int k, int l,
    GroupSize group_size,
    F&& launch_kernel) {
  // Define shapes (dynamic).
  auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L)

  // Define TN strides (mixed).
  auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL)
  auto dB = make_stride(k, Int<1>{}, n * k); // (dN,dK,dL)
  auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL)

  // Define CTA tile sizes (static).
  auto bM = Int{};
  auto bN = Int<128>{};
  auto bK = Int{};
  auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M,BLK_N,BLK_K)

  // Define MMA.
  TiledMMA mma = make_tiled_mma();
  auto num_threads = size(mma);

  // Define the A/B smem layouts (static).
  auto swizzle_ab = composition(Swizzle<3,3,3>{},
                                Layout>,
                                       Stride<_8,Stride<_1,_64>>>{});
  auto bP = Int<3>{}; // pipeline
  auto sA_layout = tile_to_shape(swizzle_ab, make_shape(bM, bK, bP));
  auto sB_layout = tile_to_shape(swizzle_ab, make_shape(bN, bK, bP));

  // Define the C smem layouts (static).
  // TODO: Find a better swizzle.
  auto sC_layout = tile_to_shape(swizzle_ab, make_shape(bM, bN));

  // Define the scales/biases smem layouts (static).
  auto bS = ceil_div(bK, group_size);
  auto sS_layout = make_layout(make_shape(bN, make_shape(group_size, bS)),
                               make_stride(bS, Stride<_0, _1>{}));

  // Define layout of scales/biases (mixed): one value per group, broadcast
  // over the group_size elements it covers (stride 0).
  auto S_layout = make_layout(
      make_shape(n, make_shape(group_size, k / group_size), l),
      make_stride(k / group_size, Stride<_0, _1>{}, n * k / group_size));

  // Atoms.
  constexpr int element_bits = sizeof_bits_v;
  constexpr int quant_bits = sizeof_bits_v;
  // Bits of quantized data that expand to 128 bits of Element.
  constexpr int qload = 128 / (element_bits / quant_bits);
  TiledCopy g2s_copy_a = make_tiled_copy(num_threads);
  TiledCopy g2s_copy_b = make_tiled_copy(num_threads);
  TiledCopy s2g_copy_c = make_tiled_copy(num_threads);

  Copy_Atom s2r_atom_a;
  Copy_Atom>, Quant> s2r_atom_b;
  Copy_Atom>, Element> r2s_atom_c;
  Copy_Atom, Scale> g2r_atom_s;

  auto* kernel = &qmm_sm80_kernel<
      decltype(prob_shape), decltype(cta_tiler),
      Element, Quant, Scale,
      decltype(dA), decltype(sA_layout), decltype(g2s_copy_a), decltype(s2r_atom_a),
      decltype(dB), decltype(sB_layout), decltype(g2s_copy_b), decltype(s2r_atom_b),
      decltype(dC), decltype(sC_layout), decltype(s2g_copy_c), decltype(r2s_atom_c),
      decltype(S_layout), decltype(g2r_atom_s),
      decltype(mma)>;

  // Set L1 to be SMEM only: allow the full dynamic shared memory footprint and
  // prefer the maximum shared-memory carveout.
  size_t smem_bytes = sizeof(SharedStorage);
  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);

  dim3 num_blocks(size(ceil_div(m, bM)), size(ceil_div(n, bN)), l);
  dim3 block_dims(num_threads);
  void* args[] = {
    &prob_shape, &cta_tiler,
    &A, &dA, &sA_layout, &g2s_copy_a, &s2r_atom_a,
    &B, &dB, &sB_layout, &g2s_copy_b, &s2r_atom_b,
    &C, &dC, &sC_layout, &s2g_copy_c, &r2s_atom_c,
    &S, &Z, &S_layout, &g2r_atom_s,
    &mma};
  launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args);
}

} // namespace cutlass_gemm

// clang-format on

namespace mlx::core {

// Calls f.template operator()<T>() with the element type matching dtype
// (float16 or bfloat16); throws std::invalid_argument for any other dtype.
template
inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
  if (dtype == float16) {
    f.template operator()();
  } else if (dtype == bfloat16) {
    f.template operator()();
  } else {
    throw std::invalid_argument(
        fmt::format("{} Unsupported dtype: {}.", tag, dtype_to_string(dtype)));
  }
}

// Calls f.template operator()<G>() with the compile-time group size G (32, 64
// or 128); throws std::invalid_argument for any other group size.
template
inline void dispatch_groups(int group_size, const char* tag, F&& f) {
  if (group_size == 32) {
    f.template operator()<32>();
  } else if (group_size == 64) {
    f.template operator()<64>();
  } else if (group_size == 128) {
    f.template operator()<128>();
  } else {
    throw std::invalid_argument(
        fmt::format("{} Group size {} is not supported.", tag, group_size));
  }
}

// Calls f with the compile-time quantization configuration. Mxfp4, Mxfp8 and
// Nvfp4 each map to one fixed instantiation; any other mode is treated as
// affine and dispatched on group_size and on bits (4 or 8 only). Throws
// std::invalid_argument for unsupported group sizes or bit widths.
template
inline void dispatch_quant_types(
    int bits,
    int group_size,
    QuantizationMode mode,
    const char* tag,
    F&& f) {
  if (mode == QuantizationMode::Mxfp4) {
    f.template operator()();
  } else if (mode == QuantizationMode::Mxfp8) {
    f.template operator()();
  } else if (mode == QuantizationMode::Nvfp4) {
    f.template operator()();
  } else {
    dispatch_groups(group_size, tag, [&]() {
      if (bits == 4) {
        f.template operator()();
      } else if (bits == 8) {
        f.template operator()();
      } else {
        throw std::invalid_argument(
            fmt::format("{} {}-bit quantization is not supported.", tag, bits));
      }
    });
  }
}

// Records the sm80 quantized matmul on encoder for one M tile size (the
// template argument): registers x, w, scales, biases (when present) and out,
// then adds the kernel as a raw kernel node. The zero-point pointer is
// nullptr when there are no biases.
template
void qmm_impl_sm80(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder) {
  const char* tag = "[quantized_matmul]";
  int m = out.shape(-2);
  int n = out.shape(-1);
  int k = x.shape(-1);
  int l = out.size() / (m * n);
  dispatch_element_types(out.dtype(), tag, [&]() {
    dispatch_quant_types(
        bits,
        group_size,
        mode,
        tag,
        [&]() {
          encoder.set_input_array(x);
          encoder.set_input_array(w);
          encoder.set_input_array(scales);
          if (biases) {
            encoder.set_input_array(*biases);
          }
          encoder.set_output_array(out);
          cutlass_gemm::qmm_sm80(
              gpu_ptr(x),
              gpu_ptr(w),
              gpu_ptr(scales),
              biases ? gpu_ptr(*biases) : nullptr,
              gpu_ptr(out),
              m,
              n,
              k,
              l,
              cute::Int{},
              [&](auto* kernel,
                  dim3 num_blocks,
                  dim3 block_dims,
                  uint32_t smem_bytes,
                  void** args) {
                encoder.add_kernel_node_raw(
                    kernel, num_blocks, block_dims, {}, smem_bytes, args);
              });
        });
  });
}

} // namespace mlx::core

// Explicitly instantiates qmm_impl_sm80 for one M tile size; each
// qmm_impl_sm80_m*.cu file expands it once.
#define QMM_SM80_GPU(TileM)             \
  namespace mlx::core {                 \
  template void qmm_impl_sm80(          \
      const array& x,                   \
      const array& w,                   \
      const array& scales,              \
      const std::optional& biases,      \
      array& out,                       \
      int bits,                         \
      int group_size,                   \
      QuantizationMode mode,            \
      cu::CommandEncoder& encoder);     \
  }

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm_impl_sm80_m16.cu
================================================
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh"

QMM_SM80_GPU(16)

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm_impl_sm80_m32.cu
================================================
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh"

QMM_SM80_GPU(32)

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm_impl_sm80_m64.cu
================================================
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh"

QMM_SM80_GPU(64)

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh
================================================
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/cutlass_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized_utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"

#include
#include
#include
#include
#include
#include

#if defined(MLX_CUDA_SM90A_ENABLED)

// We can't put kernel code in mlx::core due to name conflicts of "Shape".
namespace cutlass_gemm {

using namespace cute;

// Host-side setup for the Hopper (sm90) mixed-input GEMM assembled from CUTLASS
// collective builders: D(M,N,L) = A(M,K,L) * dequant(B(N,K,L))^T, with B
// dequantized by per-group scales S and zero points Z. The problem is handed
// to CUTLASS as (N, M, K, L) with B as the first operand (A/B swapped and
// transposed, see the note below), so S/Z follow the stride dS: contiguous
// along N, then one row of N values per group, then one slab per batch.
// launch_kernel receives the kernel pointer, grid, block and cluster shapes,
// the shared memory size and the kernel parameter array.
template <
    typename TileShapeMN = Shape<_128, _16>,
    typename ClusterShape = Shape<_1, _1, _1>,
    typename Element,
    typename Quant,
    typename GroupSize,
    typename F>
void qmm_sm90(
    const Element* A,
    const Quant* B,
    const Element* S,
    const Element* Z,
    Element* D,
    int64_t m,
    int64_t n,
    int64_t k,
    int64_t l,
    GroupSize group_size,
    F&& launch_kernel) {
  constexpr int kAlignmentA = 128 / sizeof_bits::value;
  constexpr int kAlignmentB = 128 / sizeof_bits::value;
  constexpr int kTileShapeK =
      std::max(64, 128 * 8 / sizeof_bits::value);
  // A K tile must never straddle two quantization groups.
  static_assert(group_size % kTileShapeK == 0);

  using Arch = cutlass::arch::Sm90;
  using Accumulator = float;
  using TileShape = decltype(append(TileShapeMN{}, Int{}));

  using Epilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      Arch,
      cutlass::arch::OpClassTensorOp,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      Accumulator,
      Accumulator,
      // ElementC (void: no source tensor; beta is 0 in the arguments below):
      void,
      cutlass::layout::ColumnMajor,
      kAlignmentA,
      // ElementD:
      Element,
      cutlass::layout::ColumnMajor,
      kAlignmentA,
      cutlass::epilogue::TmaWarpSpecializedCooperative>::CollectiveOp;

  // Note that A/B are swapped and transposed to use TMA epilogue.
  using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      Arch,
      cutlass::arch::OpClassTensorOp,
      // ElementA (quantized weights plus their scales and zero points):
      tuple,
      cutlass::layout::RowMajor,
      kAlignmentB,
      // ElementB:
      Element,
      cutlass::layout::ColumnMajor,
      kAlignmentA,
      Accumulator,
      TileShape,
      ClusterShape,
      // Leave room in shared memory for the epilogue's storage.
      cutlass::gemm::collective::StageCountAutoCarveout(
          sizeof(typename Epilogue::SharedStorage))>,
      cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::
      GemmUniversal, Mainloop, Epilogue>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter;

  auto dA = make_stride(k, Int<1>{}, m * k);
  auto dB = make_stride(k, Int<1>{}, n * k);
  auto dS = make_stride(Int<1>{}, n, n * k / group_size);
  auto dD = make_stride(Int<1>{}, n, m * n);

  Gemm gemm;
  // Problem (N, M, K, L); mainloop {B, A, scales, group_size, zero points};
  // epilogue alpha = 1, beta = 0 with D also passed as the C pointer.
  typename Gemm::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {int(n), int(m), int(k), int(l)},
      {B, dB, A, dA, S, dS, group_size, Z},
      {{1.f, 0.f}, D, dD, D, dD}};

  CHECK_CUTLASS_ERROR(gemm.can_implement(args));
  // NOTE(review): no workspace is passed; presumably none is needed for this
  // kernel in kGemm mode -- confirm against gemm.get_workspace_size(args).
  CHECK_CUTLASS_ERROR(gemm.initialize(args, nullptr));

  auto* kernel = &cutlass::device_kernel;
  void* kernel_params[] = {const_cast(&gemm.params())};
  auto cluster = ClusterShape{};
  launch_kernel(
      reinterpret_cast(kernel),
      gemm.get_grid_shape(gemm.params()),
      GemmKernel::get_block_shape(),
      {static_cast(get<0>(cluster)),
       static_cast(get<1>(cluster)),
       static_cast(get<2>(cluster))},
      GemmKernel::SharedStorageSize,
      kernel_params);
}

} // namespace cutlass_gemm

namespace mlx::core {

// Returns a contiguous copy of x with its last two axes swapped; the copy is
// registered with encoder.add_temporary.
inline array transpose_last_2_dims(
    const array& x,
    cu::CommandEncoder& encoder,
    const Stream& s) {
  array transposed = swapaxes_in_eval(x, -1, -2);
  array transposed_copy = contiguous_copy_gpu(transposed, s);
  encoder.add_temporary(transposed_copy);
  return transposed_copy;
}

// Calls f.template operator()<T>() with the element type matching dtype
// (float32, float16 or bfloat16); throws std::invalid_argument otherwise.
template
inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
  if (dtype == float32) {
    f.template operator()();
  } else if (dtype == float16) {
    f.template operator()();
  } else if (dtype == bfloat16) {
    f.template operator()();
  } else {
    throw std::invalid_argument(
        fmt::format("{} Unsupported dtype: {}.", tag, dtype_to_string(dtype)));
  }
}

// Calls f.template operator()<Q>() with the quant type for 2, 4 or 8 bits;
// throws std::invalid_argument for any other width.
template
inline void dispatch_quant_types(int bits, const char* tag, F&& f) {
  if (bits == 2) {
    f.template operator()();
  } else if (bits == 4) {
    f.template operator()();
  } else if (bits == 8) {
    f.template operator()();
  } else {
    throw std::invalid_argument(
        fmt::format("{} {}-bit quantization is not supported.", tag, bits));
  }
}

// Calls f with cute::Int<64>{} or cute::Int<128>{}; throws
// std::invalid_argument for any other group size.
template
inline void dispatch_groups(int group_size, const char* tag, F&& f) {
  if (group_size == 64) {
    f(cute::Int<64>{});
  } else if (group_size == 128) {
    f(cute::Int<128>{});
  } else {
    throw std::invalid_argument(
        fmt::format("{} Group size {} is not supported.", tag, group_size));
  }
}

// Records the sm90 quantized matmul on encoder for one tile/cluster shape.
// scales_ and biases_ are first copied with their last two axes swapped
// (presumably from (..., N, K/group_size) to (..., K/group_size, N), matching
// dS in cutlass_gemm::qmm_sm90 -- confirm). The dispatchers throw
// std::invalid_argument for unsupported dtypes, bit widths or group sizes.
template
void qmm_impl_sm90(
    const array& x,
    const array& w,
    const array& scales_,
    const array& biases_,
    array& out,
    int bits,
    int group_size,
    cu::CommandEncoder& encoder,
    Stream s) {
  const char* tag = "[quantized_matmul]";
  int m = out.shape(-2);
  int n = out.shape(-1);
  int k = x.shape(-1);
  int l = out.size() / (m * n);

  // FIXME: Copy happens for every call.
  array scales = transpose_last_2_dims(scales_, encoder, s);
  array biases = transpose_last_2_dims(biases_, encoder, s);

  dispatch_element_types(out.dtype(), tag, [&]() {
    dispatch_quant_types(bits, tag, [&]() {
      dispatch_groups(group_size, tag, [&](auto group_size) {
        encoder.set_input_array(x);
        encoder.set_input_array(w);
        encoder.set_input_array(scales);
        encoder.set_input_array(biases);
        encoder.set_output_array(out);
        cutlass_gemm::qmm_sm90(
            gpu_ptr(x),
            gpu_ptr(w),
            gpu_ptr(scales),
            gpu_ptr(biases),
            gpu_ptr(out),
            m,
            n,
            k,
            l,
            group_size,
            [&](auto* kernel,
                dim3 num_blocks,
                dim3 block_dims,
                dim3 cluster_shape,
                uint32_t smem_bytes,
                void** args) {
              encoder.add_kernel_node_raw(
                  kernel,
                  num_blocks,
                  block_dims,
                  cluster_shape,
                  smem_bytes,
                  args);
            });
      });
    });
  });
}

} // namespace mlx::core

// Explicitly instantiates qmm_impl_sm90 for one (TileShapeMN, ClusterShape)
// pair; each qmm_impl_sm90_m*_n*_m*.cu file expands it once. Without
// MLX_CUDA_SM90A_ENABLED it expands to nothing.
#define QMM_SM90_GPU(TileShapeMN, ClusterShape) \
  namespace mlx::core {                         \
  template void qmm_impl_sm90(                  \
      const array& x,                           \
      const array& w,                           \
      const array& scales,                      \
      const array& biases,                      \
      array& out,                               \
      int bits,                                 \
      int group_size,                           \
      cu::CommandEncoder& encoder,              \
      Stream s);                                \
  }

#else

#define QMM_SM90_GPU(TileShapeMN, ClusterShape)

#endif // defined(MLX_CUDA_SM90A_ENABLED)

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n128_m2.cu
================================================
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh"

using namespace cute;

// 128 x 128 tile with a 2x1x1 cluster.
using TileShapeMN = Shape<_128, _128>;
using ClusterShape = Shape<_2, _1, _1>;

QMM_SM90_GPU(TileShapeMN, ClusterShape)

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n16_m1.cu
================================================
// Copyright © 2026 Apple Inc.
#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh"

using namespace cute;

// 128 x 16 tile with a 1x1x1 cluster.
using TileShapeMN = Shape<_128, _16>;
using ClusterShape = Shape<_1, _1, _1>;

QMM_SM90_GPU(TileShapeMN, ClusterShape)

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n256_m2.cu
================================================
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh"

using namespace cute;

// 128 x 256 tile with a 2x1x1 cluster.
using TileShapeMN = Shape<_128, _256>;
using ClusterShape = Shape<_2, _1, _1>;

QMM_SM90_GPU(TileShapeMN, ClusterShape)

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n32_m1.cu
================================================
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh"

using namespace cute;

// 128 x 32 tile with a 1x1x1 cluster.
using TileShapeMN = Shape<_128, _32>;
using ClusterShape = Shape<_1, _1, _1>;

QMM_SM90_GPU(TileShapeMN, ClusterShape)

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n64_m2.cu
================================================
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh"

using namespace cute;

// 128 x 64 tile with a 2x1x1 cluster.
using TileShapeMN = Shape<_128, _64>;
using ClusterShape = Shape<_2, _1, _1>;

QMM_SM90_GPU(TileShapeMN, ClusterShape)

================================================
FILE: mlx/backend/cuda/quantized/qmm/qmv.cu
================================================
// Copyright © 2026 Apple Inc.
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/qmm/qmm.h"
#include "mlx/dtype_utils.h"

#include
#include
#include
#include

namespace cutlass {

// Sub-byte unsigned integer aliases used by the converters below.
using uint3b_t = integer_subbyte<3, false>;
using uint5b_t = integer_subbyte<5, false>;

// NumericArrayConverter specialization that unpacks 3-bit fields: every 3
// source bytes hold 8 values packed LSB-first, each of which is widened to T.
template
struct NumericArrayConverter {
  static_assert(N % 8 == 0);
  using result_type = Array;
  using source_type = Array;

  CUTLASS_HOST_DEVICE
  static result_type convert(const source_type& source) {
    result_type result;
    auto* s_base = reinterpret_cast(&source);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 8; ++i) {
      auto* s = s_base + i * 3;
      result[i * 8] = T(s[0] & 0x07);
      result[i * 8 + 1] = T((s[0] & 0x38) >> 3);
      // Value 2 straddles bytes 0 and 1.
      result[i * 8 + 2] = T((s[0] & 0xc0) >> 6) + T((s[1] & 0x01) << 2);
      result[i * 8 + 3] = T((s[1] & 0x0e) >> 1);
      result[i * 8 + 4] = T((s[1] & 0x70) >> 4);
      // Value 5 straddles bytes 1 and 2.
      result[i * 8 + 5] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x03) << 1);
      result[i * 8 + 6] = T((s[2] & 0x1c) >> 2);
      result[i * 8 + 7] = T((s[2] & 0xe0) >> 5);
    }
    return result;
  }

  CUTLASS_HOST_DEVICE
  result_type operator()(const source_type& s) const {
    return convert(s);
  }
};

// NumericArrayConverter specialization that unpacks 5-bit fields: every 5
// source bytes hold 8 values packed LSB-first, each of which is widened to T.
template
struct NumericArrayConverter {
  static_assert(N % 8 == 0);
  using result_type = Array;
  using source_type = Array;

  CUTLASS_HOST_DEVICE
  static result_type convert(const source_type& source) {
    result_type result;
    auto* s_base = reinterpret_cast(&source);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 8; ++i) {
      auto* s = s_base + i * 5;
      result[i * 8] = T(s[0] & 0x1f);
      result[i * 8 + 1] = T((s[0] & 0xe0) >> 5) + T((s[1] & 0x03) << 3);
      result[i * 8 + 2] = T((s[1] & 0x7c) >> 2);
      result[i * 8 + 3] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x0f) << 1);
      result[i * 8 + 4] = T((s[2] & 0xf0) >> 4) + T((s[3] & 0x01) << 4);
      result[i * 8 + 5] = T((s[3] & 0x3e) >> 1);
      result[i * 8 + 6] = T((s[3] & 0xc0) >> 6) + T((s[4] & 0x07) << 2);
      result[i * 8 + 7] = T((s[4] & 0xf8) >> 3);
    }
    return result;
  }

  CUTLASS_HOST_DEVICE
  result_type operator()(const source_type& s) const {
    return convert(s);
  }
};

// NumericArrayConverter specialization that unpacks 6-bit fields: every 3
// source bytes hold 4 values packed LSB-first, each of which is widened to T.
template
struct NumericArrayConverter {
  static_assert(N % 4 == 0);
  using result_type = Array;
  using source_type = Array;

  CUTLASS_HOST_DEVICE
  static result_type convert(const source_type& source) {
    result_type result;
    auto* s_base = reinterpret_cast(&source);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 4; ++i) {
      auto* s = s_base + i * 3;
      result[i * 4] = T(s[0] & 0x3f);
      result[i * 4 + 1] = T((s[0] >> 6) & 0x03) + T((s[1] & 0x0f) << 2);
      result[i * 4 + 2] = T((s[1] >> 4) & 0x0f) + T((s[2] & 0x03) << 4);
      result[i * 4 + 3] = T((s[2] >> 2) & 0x3f);
    }
    return result;
  }

  CUTLASS_HOST_DEVICE
  result_type operator()(const source_type& s) const {
    return convert(s);
  }
};

} // namespace cutlass

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Fused vectorized dequantize and multiply-add:
// w_dq = w * scale + bias
// out = fma(x, w_dq, out)
// x and w are reinterpreted as vectors and loaded in one access each; out is
// read and written through a vector view (expected to live in registers).
template
__device__ __forceinline__ void
dequant_fma(const T* x, const Q* w, S scale, T bias, T* out) {
  // Read x/w into registers.
  auto x_vec = *(reinterpret_cast*>(x));
  auto w_vec = *(reinterpret_cast*>(w));
  // Output is assumed to be registers.
  auto* out_vec = reinterpret_cast*>(out);

  // Dequantize w.
  cutlass::NumericArrayConverter converter_tq;
  cutlass::Array w_dq = converter_tq(w_vec);
  if constexpr (has_bias) {
    if constexpr (cuda::std::is_same_v) {
      // Element-wise path for this type instead of the whole-array operators.
#pragma unroll
      for (int i = 0; i < N; ++i) {
        w_dq[i] = w_dq[i] * T(scale) + bias;
      }
    } else {
      w_dq = w_dq * T(scale) + bias;
    }
  } else {
    w_dq = w_dq * T(scale);
  }

  // Multiply and add.
  *out_vec = cutlass::fma(x_vec, w_dq, *out_vec);
}

// Specialization for doing float32 accumulations on narrow types.
// Dequantization happens in T; x and the dequantized w are then converted to
// float before the fma into the float accumulator `out`.
template <
    int N,
    bool has_bias,
    typename T,
    typename Q,
    typename S,
    typename = cuda::std::enable_if_t>>
__device__ __forceinline__ void
dequant_fma(const T* x, const Q* w, S scale, T bias, float* out) {
  // Read x/w into registers.
  auto x_vec = *(reinterpret_cast*>(x));
  auto w_vec = *(reinterpret_cast*>(w));
  // Output is assumed to be registers.
  auto* out_vec = reinterpret_cast*>(out);

  // Dequantize w.
  cutlass::NumericArrayConverter converter_tq;
  cutlass::Array w_dq = converter_tq(w_vec);
  if constexpr (has_bias) {
    w_dq = w_dq * T(scale) + bias;
  } else {
    w_dq = w_dq * T(scale);
  }

  // Promote x/w to float.
  static_assert(!cuda::std::is_same_v);
  cutlass::NumericArrayConverter converter_ft;
  cutlass::Array x_f = converter_ft(x_vec);
  cutlass::Array w_f = converter_ft(w_dq);

  // Multiply and add.
  *out_vec = cutlass::fma(x_f, w_f, *out_vec);
}

// Quantized matrix-vector kernel: each warp computes one output element
// (one row of w dotted with the current x vector).
// Launch layout, per the host `qmv` below: blockDim = (WARP_SIZE,
// rows_per_block), gridDim = (ceil(n / rows_per_block), m, l).
template <
    int rows_per_block,
    int elems_per_thread,
    int group_size,
    bool has_bias,
    bool has_residue_k,
    typename T,
    typename Q,
    typename S>
__global__ void qmv_kernel(
    const T* x,
    const Q* w,
    const S* scales,
    const T* biases,
    T* out,
    int n,
    int k,
    bool broadcast_w) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition(block);

  // The row that this warp handles.
  int row = block.group_index().x * rows_per_block + warp.meta_group_rank();
  if (row >= n) {
    return;
  }

  // Advance pointers of x/out. The offsets are computed in 64-bit so that
  // batched inputs with more than 2^31 elements cannot overflow `int`.
  int m = grid.dim_blocks().y;
  int l = block.group_index().z;
  x += int64_t(block.group_index().y) * k + int64_t(m) * k * l;
  out += int64_t(block.group_index().y) * n + int64_t(m) * n * l;

  // For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would
  // move past 2 elements for 4-bit Q.
  constexpr int bits = cute::sizeof_bits_v;
  auto w_step = [&](int idx) { return idx * cuda::std::min(8, bits) / 8; };

  // How many groups (and scales/biases) in a row.
  int groups_per_row = k / group_size;

  // Advance w/scales/biases to current row. A 2-D w (broadcast_w) is shared
  // by every batch element.
  int w_batch = broadcast_w ? 0 : l;
  w += (static_cast(row) + n * w_batch) * w_step(k);
  scales += (static_cast(row) + n * w_batch) * groups_per_row;
  if constexpr (has_bias) {
    biases += (static_cast(row) + n * w_batch) * groups_per_row;
  }

  // Accumulations of current row. Accumulate in float when `bits >= 8`,
  // otherwise in T.
  cuda::std::conditional_t<(bits >= 8), float, T> sums[elems_per_thread] = {};

  // Dequantize-and-accumulate the elems_per_thread elements starting at idx.
  auto dequant_fma_tile = [&](int idx) {
    S scale = scales[idx / group_size];
    T bias{0};
    if constexpr (has_bias) {
      bias = biases[idx / group_size];
    }
    dequant_fma(
        x + idx, w + w_step(idx), scale, bias, sums);
  };

  // Loop over k dimension.
  constexpr int elems_per_warp = WARP_SIZE * elems_per_thread;
  for (int r = 0; r < k / elems_per_warp; ++r) {
    int idx = warp.thread_rank() * elems_per_thread + r * elems_per_warp;
    dequant_fma_tile(idx);
  }

  // Handle remaining elements in k dimension.
  if constexpr (has_residue_k) {
    int rest = k % elems_per_warp;
    int idx = warp.thread_rank() * elems_per_thread + k - rest;
    if (idx < k) {
      dequant_fma_tile(idx);
    }
  }

  // Result for current row.
  float sum{0};
#pragma unroll
  for (int i = 0; i < elems_per_thread; ++i) {
    sum += sums[i];
  }
  sum = cg::reduce(warp, sum, cg::plus{});

  // Write result for current warp, which maps to rows 1-to-1.
  if (warp.thread_rank() == 0) {
    out[row] = static_cast(sum);
  }
}

// Host-side launcher for qmv_kernel on raw device pointers. Chooses the
// per-thread vector width, builds the launch grid and hands the kernel to
// `launch_kernel(kernel, num_blocks, block_dims, args)`.
template <
    int group_size,
    bool has_bias,
    typename T,
    typename Q,
    typename S,
    typename F>
void qmv(
    const T* x,
    const Q* w,
    const S* scales,
    const T* biases,
    T* out,
    int m,
    int n,
    int k,
    int l,
    bool broadcast_w,
    F&& launch_kernel) {
  constexpr int rows_per_block = 8;
  // Wider per-thread tiles for narrow element types (presumably T <= 16 bits
  // and Q <= 4 bits; the template arguments are not visible here).
  constexpr int elems_per_thread =
      (cute::sizeof_bits_v <= 16 && cute::sizeof_bits_v <= 4) ? 16 : 8;
  dim3 num_blocks{
      uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m), uint32_t(l)};
  dim3 block_dims{WARP_SIZE, rows_per_block};
  void* args[] = {&x, &w, &scales, &biases, &out, &n, &k, &broadcast_w};
  // Only compile in the tail-handling path when k is not a multiple of the
  // per-warp tile width.
  dispatch_bool(k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) {
    auto* kernel = &qmv_kernel<
        rows_per_block,
        elems_per_thread,
        group_size,
        has_bias,
        has_residue_k.value,
        T,
        Q,
        S>;
    launch_kernel(
        reinterpret_cast(kernel), num_blocks, block_dims, args);
  });
}

} // namespace cu

// Invoke `f` templated on the C++ element type matching `dtype`
// (float32, float16 or bfloat16); throws std::invalid_argument otherwise.
template
inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
  if (dtype == float32) {
    f.template operator()();
  } else if (dtype == float16) {
    f.template operator()();
  } else if (dtype == bfloat16) {
    f.template operator()();
  } else {
    throw std::invalid_argument(
        fmt::format("{} Unsupported dtype: {}.", tag, dtype_to_string(dtype)));
  }
}

// Invoke `f` templated on a compile-time group size (32, 64 or 128);
// throws std::invalid_argument for any other value.
template
inline void dispatch_groups(int group_size, const char* tag, F&& f) {
  if (group_size == 32) {
    f.template operator()<32>();
  } else if (group_size == 64) {
    f.template operator()<64>();
  } else if (group_size == 128) {
    f.template operator()<128>();
  } else {
    throw std::invalid_argument(
        fmt::format("{} Group size {} is not supported.", tag, group_size));
  }
}

// Invoke `f` templated on the quantized storage types for `mode`. The
// Mxfp4/Mxfp8/Nvfp4 modes are dispatched directly; any other mode is
// dispatched on group_size and then on bits (2, 3, 4, 5, 6 or 8).
template
inline void dispatch_quant_types(
    int bits,
    int group_size,
    QuantizationMode mode,
    const char* tag,
    F&& f) {
  if (mode == QuantizationMode::Mxfp4) {
    f.template operator()();
  } else if (mode == QuantizationMode::Mxfp8) {
    f.template operator()();
  } else if (mode == QuantizationMode::Nvfp4) {
    f.template operator()();
  } else {
    dispatch_groups(group_size, tag, [&]() {
      if (bits == 2) {
        f.template operator()();
      } else if (bits == 3) {
        f.template operator()();
      } else if (bits == 4) {
        f.template operator()();
      } else if (bits == 5) {
        f.template operator()();
      } else if (bits == 6) {
        f.template operator()();
      } else if (bits == 8) {
        f.template operator()();
      } else {
        throw std::invalid_argument(
            fmt::format("{} {}-bit quantization is not supported.", tag, bits));
      }
    });
  }
}

// Array-level quantized matrix-vector product: out = x @ dequant(w)^T.
// m/n come from the last two dims of out, k from the last dim of x, and l is
// the number of batch matrices in out. A 2-D w is broadcast over the batch.
void qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder) {
  const char* tag = "[quantized_matmul]";
  int m = out.shape(-2);
  int n = out.shape(-1);
  int k = x.shape(-1);
  int l = out.size() / (m * n);
  bool broadcast_w = w.ndim() == 2;

  dispatch_element_types(out.dtype(), tag, [&]() {
    dispatch_quant_types(
        bits,
        group_size,
        mode,
        tag,
        [&]() {
          encoder.set_input_array(x);
          encoder.set_input_array(w);
          encoder.set_input_array(scales);
          if (biases) {
            encoder.set_input_array(*biases);
          }
          encoder.set_output_array(out);
          // Biases are only used when the element type has no negative zero
          // (the exact type argument is not visible in this extract).
          constexpr bool has_bias = !cutlass::has_negative_zero_v;
          cu::qmv(
              gpu_ptr(x),
              gpu_ptr(w),
              gpu_ptr(scales),
              biases ? gpu_ptr(*biases) : nullptr,
              gpu_ptr(out),
              m,
              n,
              k,
              l,
              broadcast_w,
              [&](auto* kernel,
                  dim3 num_blocks,
                  dim3 block_dims,
                  void** args) {
                encoder.add_kernel_node_raw(
                    kernel, num_blocks, block_dims, {}, 0, args);
              });
        });
  });
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/qqmm.cpp
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/quantized/qmm/qmm.h"
#include "mlx/backend/cuda/quantized/qqmm_impl.h"
#include "mlx/backend/cuda/quantized/qqmm_utils.h"
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/quantized/quantized_utils.h"
#include "mlx/primitives.h"

#include

namespace mlx::core {

namespace {

// Quantize `input` with fp_quantize and return (x_q, scales_x).
// x_q is packed into uint32 words (last dim = K * bits / 32) and scales_x is
// uint8. Both are registered as encoder temporaries.
// `mode` is not read in this function.
std::tuple quantize_input(
    const array& input,
    cu::CommandEncoder& encoder,
    const Stream& s,
    QuantizationMode mode,
    int bits,
    int group_size,
    std::optional global_scale = std::nullopt) {
  const array x = ensure_contiguous(input, encoder, s);

  // Compute output shapes
  auto xq_shape = x.shape();
  xq_shape.back() = x.shape(-1) * bits / 32;

  const int64_t scales_inner = x.shape(-1) / group_size;
  auto [pad_outer, pad_inner] =
      get_padded_scale_dims(x.shape(-2), scales_inner);

  // NOTE(review): the last dim is set to pad_inner and then immediately
  // overwritten with scales_inner, so the recorded shape is
  // (..., pad_outer, scales_inner) while the buffer below is sized for
  // pad_outer * pad_inner bytes per batch. Confirm this is intended.
  auto sshape = x.shape();
  sshape[x.ndim() - 2] = pad_outer;
  sshape[x.ndim() - 1] = pad_inner;
  sshape.back() = scales_inner;

  // Allocate outputs
  const int64_t xq_bytes = x.size() * bits / 8;
  const int64_t batch = x.size() / (x.shape(-2) * x.shape(-1));
  const int64_t scales_bytes = batch * (pad_outer * pad_inner);

  array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32);
  array scales_x(
      cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8);
  encoder.add_temporary(x_q);
  encoder.add_temporary(scales_x);
  // global_scale is not nullopt only for NVFP4
  fp_quantize(x, x_q, scales_x, group_size, bits, global_scale, encoder, s);
  return {std::move(x_q), std::move(scales_x)};
}

// Allocate device-side alpha/beta scalars for an NVFP4 GEMM and fill them from
// the two global scales via compute_qqmm_pointers.
GemmScalars create_nvfp4_scalars(
    const array& global_scale_x,
    const array& global_scale_w,
    cu::CommandEncoder& encoder) {
  // NVFP4 requires alpha/beta as device pointers
  // alpha = amax_x * amax_w / (448 * 6)^2
  // beta = 0
  array alpha(cu::malloc_async(sizeof(float), encoder), {}, float32);
  array beta(cu::malloc_async(sizeof(float), encoder), {}, float32);
  compute_qqmm_pointers(alpha, beta, global_scale_x, global_scale_w, encoder);
  encoder.add_temporary(alpha);
  encoder.add_temporary(beta);
  return {alpha, beta};
}

} // namespace

// Quantized x quantized matmul.
// - If w is already quantized (uint32) and x has a single row (shape(-2) == 1),
//   x is fake-quantized in place (quantize + dequantize) and the product is
//   computed by qmv. This path has no compute-capability check.
// - Otherwise the device must report compute capability >= 10.0. x (and w,
//   when not pre-quantized) are quantized, both scale tensors are
//   padded/swizzled into the tiled layout, and the product is computed by
//   qqmm_impl.
void QQMatmul::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("QQMatmul::eval_gpu");
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  auto& device = encoder.device();
  bool w_quantized = (inputs[1].dtype() == uint32);
  int base_size = w_quantized ? 3 : 2;

  assert(
      inputs.size() == base_size ||
      (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2));

  if (w_quantized && inputs[0].shape(-2) == 1) {
    out.set_data(cu::malloc_async(out.nbytes(), encoder));

    // For nvfp4, get global scale for x from inputs if present
    bool has_global_scale =
        mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;
    std::optional global_scale = std::nullopt;
    if (has_global_scale) {
      global_scale = inputs[inputs.size() - 2];
    }

    bool donate_x = inputs[0].is_donatable();
    array x = ensure_row_contiguous(inputs[0], encoder, s);
    // If x is a copy it should be donatable
    donate_x |= x.is_donatable();
    // Reuse x's buffer for the fake-quantized result when it can be donated.
    auto xhat = donate_x
        ? x
        : array(cu::malloc_async(x.nbytes(), encoder), x.shape(), x.dtype());
    if (!donate_x) {
      encoder.add_temporary(xhat);
    }
    fp_quantize_dequantize(
        x, xhat, group_size_, bits_, global_scale, encoder, s);

    const array& w = inputs[1];
    const array& scales = inputs[2];
    qmv(xhat, w, scales, std::nullopt, out, bits_, group_size_, mode_, encoder);
    return;
  }

  // Encodes e.g. compute capability 10.0 as 1000.
  auto cc = device.compute_capability_major() * 100 +
      device.compute_capability_minor() * 10;
  if (cc < 1000) {
    throw std::runtime_error(
        "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher.");
  }

  // - 2 inputs: x, w (non-quantized w)
  // - 3 inputs: x, w, scales_w (quantized w)
  // For nvfp4, global scales are optional but must be both present or both
  // absent. If present, they add 2 more inputs (global_scale_x, global_scale_w)
  bool has_global_scales =
      mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;

  // For nvfp4, get global scales from inputs if present
  std::optional global_scale_x = std::nullopt;
  std::optional global_scale_w = std::nullopt;
  if (has_global_scales) {
    global_scale_x = inputs[inputs.size() - 2];
    global_scale_w = inputs[inputs.size() - 1];
  }

  // Quantize inputs (or use pre-quantized)
  auto [x_q, scale_x_pre] = quantize_input(
      inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x);
  auto [w_q, scale_w_pre] = !w_quantized
      ? quantize_input(
            inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w)
      : std::make_tuple(
            ensure_contiguous(inputs[1], encoder, s),
            ensure_contiguous(inputs[2], encoder, s));

  out.set_data(cu::malloc_async(out.nbytes(), encoder));

  int M = x_q.shape(-2);
  int N = w_q.shape(-2); // transposed
  // x_q packs 32 / bits_ quantized values per uint32 word.
  int K = x_q.shape(-1) * (32 / bits_);

  bool x_transposed = false;
  bool w_transposed = true; // always transposed
  int64_t lda = K;
  int64_t ldb = K;

  // Repack scales to tiled layout for tensor cores
  array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s);
  array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s);

  // Left empty (no device alpha/beta) unless NVFP4 global scales were given.
  GemmScalars scalars;
  if (has_global_scales) {
    scalars = create_nvfp4_scalars(*global_scale_x, *global_scale_w, encoder);
  }

  qqmm_impl(
      encoder,
      M,
      N,
      K,
      x_transposed,
      lda,
      w_transposed,
      ldb,
      out,
      x_q,
      w_q,
      scale_x,
      scale_w,
      mode_,
      scalars);
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/qqmm_impl.cpp
================================================
// Copyright © 2026 Apple Inc.
#include "mlx/backend/cuda/quantized/qqmm_impl.h"
#include "mlx/backend/cuda/quantized/cublas_qqmm.h"

namespace mlx::core {

// Run a single (non-batched) quantized GEMM through CublasQQMM.
// When `scalars` carries device-side alpha/beta they are passed to run();
// otherwise the overload without them is used.
void qqmm_impl(
    cu::CommandEncoder& encoder,
    int M,
    int N,
    int K,
    bool a_transposed,
    int64_t lda,
    bool b_transposed,
    int64_t ldb,
    array& out,
    const array& a,
    const array& b,
    const array& a_scale,
    const array& b_scale,
    QuantizationMode mode,
    const GemmScalars& scalars) {
  std::string qmode = quantization_mode_to_string(mode);

  CublasQQMM qqmm(
      encoder.device(),
      a_transposed,
      M,
      K,
      lda,
      b_transposed,
      K,
      N,
      ldb,
      1, // batch_count
      0, // a_batch_stride
      0, // b_batch_stride
      out.dtype(),
      qmode);

  if (scalars.has_values()) {
    // NOTE(review): has_values() only checks alpha_device, but beta_device is
    // dereferenced here as well. Callers must set both together.
    qqmm.run(
        encoder,
        out,
        a,
        b,
        a_scale,
        b_scale,
        *scalars.alpha_device,
        *scalars.beta_device);
  } else {
    qqmm.run(encoder, out, a, b, a_scale, b_scale);
  }
}

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/qqmm_impl.h
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"
#include "mlx/primitives.h"

#include

namespace mlx::core {

// Optional device-resident GEMM scalars (alpha/beta).
struct GemmScalars {
  std::optional alpha_device;
  std::optional beta_device;

  // True when alpha_device is set (beta_device is not checked).
  bool has_values() const {
    return alpha_device.has_value();
  }
};

// Quantized GEMM entry point; see qqmm_impl.cpp.
void qqmm_impl(
    cu::CommandEncoder& encoder,
    int M,
    int N,
    int K,
    bool a_transposed,
    int64_t lda,
    bool b_transposed,
    int64_t ldb,
    array& out,
    const array& a,
    const array& b,
    const array& a_scale,
    const array& b_scale,
    QuantizationMode mode,
    const GemmScalars& scalars = {});

} // namespace mlx::core
================================================
FILE: mlx/backend/cuda/quantized/qqmm_utils.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/qqmm_utils.h"

#include

namespace mlx::core {

namespace cg = cooperative_groups;

constexpr int TILE_ROWS = 128;
constexpr int TILE_COLS = 4;
constexpr int TILES_PER_LANE = 1;
constexpr int LANES_PER_BLOCK = 32;

// To pass scales to tensor cores, they need to be repacked into a tiled layout
// https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
// The tiled layout for scale factors is described in detail in the CUTLASS
// documentation:
// https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts
// Conceptually, it should be like this:
// q_w = mx.zeros(shape=(M, N)) <-- zeros just for an example
// s.shape = (M, N // 16) -- packed in row contiguous order, group_size = 16
// cbg_cnt = N // 16 // 4
// rb_cnt = M // 128
// tmp = s.reshape(rb_cnt, 4, 32, cbg_cnt, 4)
// repacked_scales = tmp.transpose(0, 3, 2, 1, 4)
// Example: indices of the initial 128 x 4 tile of scales (packed in a row-major
// tensor (M, K // 16), where M = 128, K = 64):
// array([[0, 1, 2, 3],
//        [4, 5, 6, 7],
//        [8, 9, 10, 11],
//        ...,
//        [500, 501, 502, 503],
//        [504, 505, 506, 507],
//        [508, 509, 510, 511]]
// packed scales within tile 128 x 4:
// array([[[[[0, 1, 2, 3],         <-- s_0,0..s_0,3 scales
//           [128, 129, 130, 131], <-- s_32,0..s_32,3 scales
//           [256, 257, 258, 259], <-- s_64,0..s_64,3 scales
//           [384, 385, 386, 387]], <-- s_96,0..s_96,3 scales
//          [[4, 5, 6, 7],         <-- s_1,0..s_1,3 scales
//           [132, 133, 134, 135], ...
//           [260, 261, 262, 263],
//           [388, 389, 390, 391]],
//          [[124, 125, 126, 127],
//           [252, 253, 254, 255],
//           [380, 381, 382, 383],
//           [508, 509, 510, 511]]]]],

// Launch configuration for cu::swizzle_scales: grid.x covers column tiles in
// chunks of tiles_per_block, grid.y covers row tiles.
inline std::tuple get_swizzle_launch_args(
    size_t M_swizzled,
    size_t K_swizzled) {
  constexpr int tiles_per_block = LANES_PER_BLOCK * TILES_PER_LANE;
  constexpr int warps_per_block = TILE_ROWS / 4; // 128 / 4 = 32

  const int num_tiles_k = K_swizzled / TILE_COLS;
  const int num_tiles_m = M_swizzled / TILE_ROWS;

  dim3 grid;
  grid.x = cuda::ceil_div(num_tiles_k, tiles_per_block);
  grid.y = num_tiles_m;
  grid.z = 1;
  // Block is always (32, 32) = 1024 threads
  dim3 block(LANES_PER_BLOCK, warps_per_block, 1);

  return std::make_tuple(grid, block);
}

namespace cu {

constexpr float F8E4M3_MAX = 448.0f;
constexpr float F4E2M1_MAX = 6.0f;

// Single-thread kernel writing the NVFP4 GEMM scalars to device memory.
__global__ void compute_qqmm_pointers(
    float* alpha_out,
    float* beta_out,
    const float* tensor_amax_x,
    const float* tensor_amax_w) {
  // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2
  constexpr float inv_scale_sq =
      1.0f / (F8E4M3_MAX * F4E2M1_MAX * F8E4M3_MAX * F4E2M1_MAX);
  *alpha_out = (*tensor_amax_x) * (*tensor_amax_w) * inv_scale_sq;
  *beta_out = 0.0f;
}

// Repack row-major (M, K) uint8 scales into the 128 x 4 tiled layout of size
// (M_swizzled, K_swizzled). Entries outside the (M, K) input are written as 0.
__global__ void swizzle_scales(
    const uint8_t* scales_linear,
    uint8_t* scales_swizzled,
    const size_t M,
    const size_t K,
    const size_t M_swizzled,
    const size_t K_swizzled) {
  constexpr int tile_size = TILE_ROWS * TILE_COLS;
  constexpr int num_tile_rows_per_thread = 4;
  constexpr int max_tiles_per_block = LANES_PER_BLOCK * TILES_PER_LANE;
  constexpr int tile_stride = tile_size / 16; // 32 int4s per tile

  // Each thread loads 16 scales from 4 rows (stride 32) and packs them into
  // int4. For example: thread (0, 0) loads scales at rows 0,32,64,96 of tile 0,
  // thread (1, 0) loads rows 0,32,64,96 of tile 1, etc.
  // The store is strided within a warp (stride 32 int4s), so we first
  // write to shared memory, then do a coalesced store from shared to global
  auto block_size = cg::this_thread_block().dim_threads();
  auto block_idx = cg::this_thread_block().group_index();
  auto idx_in_block = cg::this_thread_block().thread_index();

  auto tidx = idx_in_block.x;
  auto tidy = idx_in_block.y;
  auto linear_tid = tidy * block_size.x + tidx;

  const int bid_x = block_idx.x;
  const int bid_y = block_idx.y;

  // Output row width measured in 4-byte words (one word = one tile row).
  const int K_int = K_swizzled / 4;

  const size_t output_offset = static_cast(bid_y) * TILE_ROWS * K_int +
      static_cast(bid_x) * max_tiles_per_block * tile_size / 4;
  int* output_block = reinterpret_cast(scales_swizzled) + output_offset;

  // Number of column tiles this block owns (the last block may own fewer).
  int remaining = K_int - bid_x * max_tiles_per_block;
  int tiles_in_block = min(remaining, max_tiles_per_block);
  bool valid_tile = tidx * TILES_PER_LANE < tiles_in_block;

  __shared__ int4 strided_scales_thread[max_tiles_per_block * tile_stride];

  // Initialize to zero for padding
  int thread_tile_rows[num_tile_rows_per_thread] = {0};

  if (valid_tile) {
    const size_t col_base =
        static_cast(bid_x) * max_tiles_per_block * TILE_COLS +
        tidx * TILE_COLS;

    const bool aligned_k = (K % 4 == 0);

    if (aligned_k) {
      // fast path: K is aligned, use vectorized loads with stride K/4
      const int K_stride = K / 4;
      const size_t block_offset =
          static_cast(bid_y) * TILE_ROWS * K_stride +
          static_cast(bid_x) * max_tiles_per_block;
      const int* input_block =
          reinterpret_cast(scales_linear) + block_offset;
      // load
#pragma unroll
      for (int i = 0; i < num_tile_rows_per_thread; i++) {
        const size_t row =
            static_cast(bid_y) * TILE_ROWS + i * block_size.x + tidy;
        const int thread_offset =
            (i * block_size.x + tidy) * K_stride + tidx * TILES_PER_LANE;
        if (row < M && col_base + TILE_COLS <= K) {
          thread_tile_rows[i] = __ldg(input_block + thread_offset);
        } else if (row < M) {
          // partial tile at K boundary: load byte-by-byte
#pragma unroll
          for (int c = 0; c < TILE_COLS; c++) {
            if (col_base + c < K) {
              reinterpret_cast(&thread_tile_rows[i])[c] =
                  scales_linear[row * K + col_base + c];
            }
          }
        }
      }
    } else {
      // slow path: K is not a multiple of 4, load byte-by-byte
#pragma unroll
      for (int i = 0; i < num_tile_rows_per_thread; i++) {
        const size_t row =
            static_cast(bid_y) * TILE_ROWS + i * block_size.x + tidy;
        if (row < M) {
          const size_t row_start = row * K;
#pragma unroll
          for (int c = 0; c < TILE_COLS; c++) {
            if (col_base + c < K) {
              reinterpret_cast(&thread_tile_rows[i])[c] =
                  scales_linear[row_start + col_base + c];
            }
          }
        }
      }
    }
    // store to shared with XOR swizzle to avoid bank conflicts
    int base_idx = tidx * tile_stride + tidy;
    int xor_bits = (tidy >> 3) & 0x3;
    int swizzled_idx = base_idx ^ xor_bits;
    strided_scales_thread[swizzled_idx] =
        *reinterpret_cast(thread_tile_rows);
  }

  cg::thread_block block = cg::this_thread_block();
  cg::sync(block);

  // Coalesced copy from shared to global, undoing the XOR swizzle on read.
  const int total_int4s = tiles_in_block * tile_stride;
#pragma unroll
  for (int i = linear_tid; i < total_int4s;
       i += block_size.x * block_size.y) {
    int tile_idx = i / tile_stride;
    int row_idx = i % tile_stride;
    int base_idx = tile_idx * tile_stride + row_idx;
    int xor_bits = (row_idx >> 3) & 0x3;
    int swizzled_idx = base_idx ^ xor_bits;
    reinterpret_cast(output_block)[i] = strided_scales_thread[swizzled_idx];
  }
}

} // namespace cu

// Host wrapper: swizzle the last two dims of `scales` into `scales_tiled`.
// NOTE(review): only shape(-2)/shape(-1) are passed to the kernel; batched
// scales are presumably not expected here. Confirm with callers.
void swizzle_scales(
    const array& scales,
    array& scales_tiled,
    cu::CommandEncoder& enc,
    const Stream& s) {
  enc.set_input_array(scales);
  enc.set_output_array(scales_tiled);

  // Note: scales_tiled is padded to full tiles, so its dims can exceed those
  // of scales when num_rows or num_cols are not multiples of the tile sizes.
  size_t input_rows = scales.shape(-2);
  size_t input_cols = scales.shape(-1);

  size_t output_rows = scales_tiled.shape(-2);
  size_t output_cols = scales_tiled.shape(-1);

  auto [num_blocks, block_dims] =
      get_swizzle_launch_args(output_rows, output_cols);
  enc.add_kernel_node(
      cu::swizzle_scales,
      num_blocks,
      block_dims,
      gpu_ptr(scales),
      gpu_ptr(scales_tiled),
      input_rows,
      input_cols,
      output_rows,
      output_cols);
}
void compute_qqmm_pointers( array& alpha_out, array& beta_out, const array& tensor_amax_x, const array& tensor_amax_w, cu::CommandEncoder& enc) { enc.set_input_array(tensor_amax_x); enc.set_input_array(tensor_amax_w); enc.set_output_array(alpha_out); enc.set_output_array(beta_out); enc.add_kernel_node( cu::compute_qqmm_pointers, dim3(1), dim3(1), gpu_ptr(alpha_out), gpu_ptr(beta_out), gpu_ptr(tensor_amax_x), gpu_ptr(tensor_amax_w)); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/quantized/qqmm_utils.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include "mlx/array.h" #include "mlx/backend/cuda/device.h" namespace mlx::core { // Compute padded dimensions for tiled layout // Tiles are 128 rows × 4 columns, must allocate full tiles inline std::pair get_padded_scale_dims(int num_rows, int num_cols) { constexpr int rows_per_tile = 128; constexpr int cols_per_tile = 4; int padded_rows = ((num_rows + rows_per_tile - 1) / rows_per_tile) * rows_per_tile; int padded_cols = ((num_cols + cols_per_tile - 1) / cols_per_tile) * cols_per_tile; return {padded_rows, padded_cols}; } void swizzle_scales( const array& scales, array& scales_tiled, cu::CommandEncoder& enc, const Stream& s); inline array pad_and_swizzle_scales( const array& scale, cu::CommandEncoder& encoder, const Stream& s) { // Compute padded dimensions for full tiles (128 rows × 4 cols) auto [pad_outer, pad_inner] = get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); // cuBLAS requirements for scale factor layout: // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) // 2. Out-of-bounds values must be filled with zeros // 3. 
Starting addresses must be 16-byte aligned // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout // Note: cu::malloc_async already provides 256-byte alignment array scale_tiled( cu::malloc_async(pad_outer * pad_inner, encoder), Shape{pad_outer, pad_inner}, scale.dtype()); swizzle_scales(scale, scale_tiled, encoder, s); encoder.add_temporary(scale_tiled); return scale_tiled; } // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2 // Allocate beta zero on device as well void compute_qqmm_pointers( array& alpha_out, array& beta_out, const array& tensor_amax_x, const array& tensor_amax_w, cu::CommandEncoder& enc); } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/quantized/quantized.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/quantized/qmm/qmm.h" #include "mlx/backend/cuda/quantized/quantized_utils.h" #include "mlx/dtype_utils.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" #include namespace mlx::core { void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("QuantizedMatmul::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); const array& x = inputs[0]; const array& w = inputs[1]; const array& scales = inputs[2]; std::optional biases; if (inputs.size() > 3) { biases = inputs[3]; } auto supports = [&](auto&& f) { return f( x, w, scales, biases, out, transpose_, bits_, group_size_, mode_, encoder.device()); }; bool can_use_qmm_sm90 = supports(supports_qmm_sm90); bool can_use_qmm_sm80 = supports(supports_qmm_sm80); bool can_use_fp_qmv = supports(supports_fp_qmv); bool can_use_qmv = supports(supports_qmv) || can_use_fp_qmv; auto call_qmm_sm90 = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); qmm_sm90(x, w, scales, *biases, out, bits_, group_size_, 
encoder, s); }; auto call_qmm_sm80 = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); qmm_sm80(x, w, scales, biases, out, bits_, group_size_, mode_, encoder); }; auto call_qmv = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); if (can_use_fp_qmv) { fp_qmv(x, w, scales, out, bits_, group_size_, encoder, s); } else { qmv(x, w, scales, biases, out, bits_, group_size_, mode_, encoder); } }; int M = out.shape(-2); int N = out.shape(-1); int K = x.shape(-1); int B = out.size() / (M * N); if (can_use_qmm_sm90) { if (can_use_qmv && (M == 1 && B == 1 && N <= 16384 && K <= 16384)) { call_qmv(); } else { call_qmm_sm90(); } return; } if (can_use_qmm_sm80) { if (can_use_qmv && (M * B < 8)) { call_qmv(); } else { call_qmm_sm80(); } return; } if (can_use_qmv) { call_qmv(); return; } throw std::runtime_error( fmt::format( "[quantized_matmul] No implementation for " "problem shape: {}x{}x{}x{}, transpose: {}, " "activation: {}, bits: {}, group size: {}, mode: \"{}\".", M, N, K, B, transpose_, dtype_to_string(x.dtype()), bits_, group_size_, quantization_mode_to_string(mode_))); } void fast::Quantize::eval_gpu( const std::vector& inputs, std::vector& outputs) { nvtx3::scoped_range r("Quantize::eval_gpu"); auto& s = stream(); auto& d = cu::device(s.device); auto& enc = d.get_command_encoder(s); if (dequantize_) { auto wq = ensure_row_contiguous(inputs[0], enc, s); auto scales = ensure_row_contiguous(inputs[1], enc, s); auto& w = outputs[0]; w.set_data(cu::malloc_async(w.nbytes(), enc)); if (mode_ == QuantizationMode::Affine) { auto biases = ensure_row_contiguous(inputs[2], enc, s); affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); } else { // 0 -- xq, 1 -- scales, 2 -- could be global scale for nvfp4 bool use_global_scale = mode_ == QuantizationMode::Nvfp4 && inputs.size() > 2; std::optional global_scale = use_global_scale ? 
std::make_optional(inputs[2]) : std::nullopt; fp_dequantize(wq, scales, w, group_size_, bits_, global_scale, enc, s); } } else { auto w = ensure_contiguous(inputs[0], enc, s); auto& wq = outputs[0]; auto& scales = outputs[1]; wq.set_data(cu::malloc_async(wq.nbytes(), enc)); scales.set_data(cu::malloc_async(scales.nbytes(), enc)); if (mode_ == QuantizationMode::Affine) { auto& biases = outputs[2]; biases.set_data(cu::malloc_async(biases.nbytes(), enc)); affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); } else { bool use_global_scale = mode_ == QuantizationMode::Nvfp4 && inputs.size() > 1; std::optional global_scale = use_global_scale ? std::make_optional(inputs[1]) : std::nullopt; fp_quantize(w, wq, scales, group_size_, bits_, global_scale, enc, s); } } } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/quantized/quantized.h ================================================ // Copyright © 2025 Apple Inc. #include #include "mlx/backend/cuda/device.h" namespace mlx::core { void affine_quantize( const array& w, array& wq, array& scales, array& biases, int group_size_, int bits_, cu::CommandEncoder& enc, const Stream& s); void affine_dequantize( const array& wq, const array& scales, const array& biases, array& w, int group_size_, int bits_, cu::CommandEncoder& enc, const Stream& s); void fp_quantize( const array& w, array& wq, array& scales, int group_size, int bits, const std::optional& global_scale, cu::CommandEncoder& enc, const Stream& s); void fp_dequantize( const array& wq, const array& scales, array& w, int group_size, int bits, const std::optional& global_scale, cu::CommandEncoder& enc, const Stream& s); void fp_quantize_dequantize( const array& w, array& what, int group_size, int bits, const std::optional& global_scale, cu::CommandEncoder& enc, const Stream& s); } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/quantized/quantized_utils.h 
================================================ // Copyright © 2026 Apple Inc. #include "mlx/backend/cuda/device.h" #include "mlx/backend/gpu/copy.h" namespace mlx::core { inline array ensure_row_contiguous( const array& x, cu::CommandEncoder& enc, const Stream& s) { if (!x.flags().row_contiguous) { array x_copy = contiguous_copy_gpu(x, s); enc.add_temporary(x_copy); return x_copy; } else { return x; } } inline array ensure_row_contiguous_matrix( const array& x, cu::CommandEncoder& enc, const Stream& s) { if (x.ndim() < 2) { if (x.strides()[0] == 1) { return x; } } else { auto stride_0 = x.strides()[x.ndim() - 2]; auto stride_1 = x.strides()[x.ndim() - 1]; if (stride_0 == x.shape(-1) && stride_1 == 1) { return x; } } array x_copy = contiguous_copy_gpu(x, s); enc.add_temporary(x_copy); return x_copy; } inline array ensure_contiguous(const array& x, cu::CommandEncoder& enc, const Stream& s) { if (x.flags().row_contiguous || x.flags().col_contiguous) { return x; } array x_copy = contiguous_copy_gpu(x, s); enc.add_temporary(x_copy); return x_copy; } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/random.cu ================================================ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>

#include <cassert>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

__constant__ constexpr uint32_t rotations[2][4] = {
    {13, 15, 26, 6},
    {17, 29, 16, 24}};

union rbits {
  uint2 val;
  uint8_t bytes[2][4];
};

// Hash the 64-bit counter `count` under the 64-bit `key`: 5 rounds of 4
// add/rotate/xor steps each, with a key injection after every round.
__device__ rbits threefry2x32_hash(uint2 key, uint2 count) {
  uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};

  rbits v;
  v.val.x = count.x + ks[0];
  v.val.y = count.y + ks[1];

  for (int i = 0; i < 5; ++i) {
    for (auto r : rotations[i % 2]) {
      v.val.x += v.val.y;
      v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
      v.val.y ^= v.val.x;
    }
    v.val.x += ks[(i + 1) % 3];
    v.val.y += ks[(i + 2) % 3] + i + 1;
  }

  return v;
}

// Copy one 4-byte word of random bits to out[idx, idx + 4), never writing at
// or past bytes_per_key. Every word starts inside the key's output
// (idx < bytes_per_key), so only the key's final word can be partial. This
// also covers a single-word output shorter than 4 bytes, which previously
// spilled into the neighbouring key (or past the end of the buffer).
__device__ void write_rbits_word(
    uint8_t* out,
    const uint8_t (&word)[4],
    size_t idx,
    uint32_t bytes_per_key) {
  size_t remaining = bytes_per_key - idx;
  int count = remaining < 4 ? static_cast<int>(remaining) : 4;
  for (int i = 0; i < count; ++i) {
    out[idx + i] = word[i];
  }
}

// Produce the bits owned by counter row `index_y` of one key. Row index_y
// writes word index_y and, unless it is the unpaired middle word of an
// odd-length output (drop_last), word index_y + grid_y.
__device__ void write_rbits_for_key(
    uint2 key,
    uint8_t* out,
    uint32_t index_y,
    uint32_t grid_y,
    bool odd,
    uint32_t bytes_per_key) {
  uint32_t half_size = grid_y - odd;
  bool drop_last = odd && (index_y == half_size);
  auto bits = threefry2x32_hash(
      key, uint2{index_y, drop_last ? 0 : index_y + grid_y});
  write_rbits_word(out, bits.bytes[0], size_t(index_y) << 2, bytes_per_key);
  if (!drop_last) {
    write_rbits_word(
        out, bits.bytes[1], (size_t(index_y) + grid_y) << 2, bytes_per_key);
  }
}

// Random bits for row-contiguous keys. Thread (index_x, index_y) handles key
// index_x and counter row index_y.
__global__ void rbitsc(
    const uint32_t* keys,
    uint8_t* out,
    dim3 grid_dims,
    bool odd,
    uint32_t bytes_per_key) {
  auto grid = cg::this_grid();
  // thread_rank() is 64-bit. Keep the full width so launches with more than
  // 2^32 threads (which the host supports) are not aliased onto low indices.
  size_t thread_index = grid.thread_rank();
  size_t index_x = thread_index % grid_dims.x;
  size_t index_y = thread_index / grid_dims.x;
  if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
    return;
  }

  size_t kidx = 2 * index_x;
  auto key = uint2{keys[kidx], keys[kidx + 1]};
  // 64-bit offset: num_keys * bytes_per_key may exceed 2^32.
  write_rbits_for_key(
      key,
      out + index_x * bytes_per_key,
      static_cast<uint32_t>(index_y),
      grid_dims.y,
      odd,
      bytes_per_key);
}

// Random bits for arbitrarily strided keys. Same work split as rbitsc, but
// the two key words are located through elem_to_loc.
__global__ void rbits(
    const uint32_t* keys,
    uint8_t* out,
    dim3 grid_dims,
    bool odd,
    uint32_t bytes_per_key,
    int32_t ndim,
    const __grid_constant__ Shape key_shape,
    const __grid_constant__ Strides key_strides) {
  auto grid = cg::this_grid();
  size_t thread_index = grid.thread_rank();
  size_t index_x = thread_index % grid_dims.x;
  size_t index_y = thread_index / grid_dims.x;
  if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
    return;
  }

  int64_t kidx = 2 * static_cast<int64_t>(index_x);
  auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
  auto k2_elem =
      elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
  auto key = uint2{keys[k1_elem], keys[k2_elem]};
  write_rbits_for_key(
      key,
      out + index_x * bytes_per_key,
      static_cast<uint32_t>(index_y),
      grid_dims.y,
      odd,
      bytes_per_key);
}

} // namespace cu

void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("RandomBits::eval_gpu");
  assert(inputs.size() == 1);

  // keys has shape (N1, ..., NK, 2)
  // out has shape (N1, ..., NK, M1, M2, ...)
  auto& keys = inputs[0];
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  out.set_data(cu::malloc_async(out.nbytes(), encoder));
  if (out.size() == 0) {
    return;
  }

  // Derived only after the empty check so an empty key batch is never used
  // as a divisor.
  size_t num_keys = keys.size() / 2;
  size_t elems_per_key = out.size() / num_keys;
  size_t bytes_per_key = out.itemsize() * elems_per_key;
  size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
  size_t half_size = out_per_key / 2;
  bool odd = out_per_key % 2;

  // The kernels take key counts, counter rows and per-key byte counts as
  // 32-bit values.
  if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX ||
      bytes_per_key >= UINT32_MAX) {
    throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
  }

  encoder.set_input_array(keys);
  encoder.set_output_array(out);

  int64_t total = num_keys * (half_size + odd);
  uint32_t threads_y = 1;
  while ((total / threads_y) >= UINT_MAX) {
    threads_y *= 2;
  }
  uint32_t threads_x = cuda::ceil_div(total, threads_y);

  dim3 grid_dims{
      static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
  auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
  if (keys.flags().row_contiguous) {
    encoder.add_kernel_node(
        cu::rbitsc,
        grid,
        block,
        gpu_ptr<uint32_t>(keys),
        gpu_ptr<uint8_t>(out),
        grid_dims,
        odd,
        static_cast<uint32_t>(bytes_per_key));
  } else {
    encoder.add_kernel_node(
        cu::rbits,
        grid,
        block,
        gpu_ptr<uint32_t>(keys),
        gpu_ptr<uint8_t>(out),
        grid_dims,
        odd,
        static_cast<uint32_t>(bytes_per_key),
        static_cast<int32_t>(keys.ndim()),
        const_param(keys.shape()),
        const_param(keys.strides()));
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/reduce/all_reduce.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Reduce in[start, min(start + block_step, size)), where
// start = block_rank * block_step, and write this threadblock's partial
// result to out[block_rank]. Each thread consumes N elements per step.
template <typename T, typename U, typename ReduceOp, int N = 4>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
  // TODO: Process multiple "rows" in each thread
  constexpr int M = 1;

  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  const U init = cu::ReduceInit<ReduceOp, T>::value();
  ReduceOp op;

  T vals[N];
  U accs[M];
  accs[0] = init;

  size_t start = grid.block_rank() * block_step;
  size_t end = start + block_step;
  size_t check = min(end, size);

  size_t i = start;
  // Full steps: every thread loads N contiguous elements.
  for (; i + block.size() * N <= check; i += block.size() * N) {
    cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
    for (int j = 0; j < N; j++) {
      accs[0] = op(accs[0], cast_to<U>(vals[j]));
    }
  }

  // Tail: slots past `check` are filled with the reduction identity.
  if (i < check) {
    cub::LoadDirectBlocked(
        block.thread_rank(), in + i, vals, check - i, cast_to<T>(init));
    for (int j = 0; j < N; j++) {
      accs[0] = op(accs[0], cast_to<U>(vals[j]));
    }
  }

  __shared__ U shared_accumulators[32];
  block_reduce(block, warp, accs, shared_accumulators, op, init);

  if (block.thread_rank() == 0) {
    out[grid.block_rank()] = accs[0];
  }
}

} // namespace cu

// Reduce every element of `in` into the single element of `out`. Large inputs
// are first reduced into one partial value per block, then reduced again.
void all_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type) {
  constexpr int N_READS = 8;

  out.set_data(cu::malloc_async(out.nbytes(), encoder));

  // Launch configuration for reducing `size` elements with N reads per thread
  // per step. `size` is size_t so inputs above INT_MAX elements are not
  // truncated.
  auto get_args = [](size_t size, int N) {
    size_t wanted_threads = (size + N - 1) / N;
    int threads = static_cast<int>(std::min<size_t>(512, wanted_threads));
    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
    // At least one warp, so an empty input never yields a zero-sized step
    // (which would divide by zero below).
    threads = std::max(threads, WARP_SIZE);
    size_t reductions_per_step = static_cast<size_t>(threads) * N;
    size_t steps_needed =
        (size + reductions_per_step - 1) / reductions_per_step;
    int blocks;
    if (steps_needed < 32) {
      blocks = 1;
    } else if (steps_needed < 128) {
      blocks = 32;
    } else if (steps_needed < 512) {
      blocks = 128;
    } else if (steps_needed < 1024) {
      blocks = 512;
    } else {
      blocks = 1024;
    }
    size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
    size_t block_step = steps_per_block * reductions_per_step;
    return std::make_tuple(blocks, threads, block_step);
  };

  int blocks, threads;
  size_t block_step;
  size_t insize = in.size();
  Dtype dt = in.dtype();

  // Cub doesn't like const pointers for load (sigh).
  void* indata = const_cast<void*>(gpu_ptr<void>(in));

  std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
  encoder.set_input_array(in);

  // Large array so allocate an intermediate and accumulate there
  if (blocks > 1) {
    array intermediate({blocks}, out.dtype(), nullptr, {});
    intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
    encoder.add_temporary(intermediate);
    encoder.set_output_array(intermediate);
    dispatch_all_types(dt, [&](auto type_tag) {
      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
        using OP = MLX_GET_TYPE(reduce_type_tag);
        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        using U = typename cu::ReduceResult<OP, T>::type;
        auto kernel = cu::all_reduce<T, U, OP, N_READS>;
        encoder.add_kernel_node(
            kernel,
            blocks,
            threads,
            static_cast<T*>(indata),
            gpu_ptr<U>(intermediate),
            block_step,
            insize);
      });
    });

    // Set the input for the next step and recalculate the blocks
    indata = gpu_ptr<void>(intermediate);
    dt = intermediate.dtype();
    insize = intermediate.size();
    std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
    encoder.set_input_array(intermediate);
  }

  encoder.set_output_array(out);
  dispatch_all_types(dt, [&](auto type_tag) {
    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
      using OP = MLX_GET_TYPE(reduce_type_tag);
      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
      using U = typename cu::ReduceResult<OP, T>::type;
      auto kernel = cu::all_reduce<T, U, OP, N_READS>;
      encoder.add_kernel_node(
          kernel,
          blocks,
          threads,
          static_cast<T*>(indata),
          gpu_ptr<U>(out),
          block_step,
          insize);
    });
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/reduce/col_reduce.cu
================================================
// Copyright © 2025 Apple Inc.
#include
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include
#include
#include
#include
// NOTE(review): in this copy the header names of the bare #include lines and
// most template argument lists (text between '<' and '>') appear to have been
// stripped; the code below is kept exactly as it appears here.
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;

// Kernel-side description of a column reduction. The last reduction axis has
// reduction_size elements spaced reduction_stride apart; it is repeated over
// the other reduction axes (reduce_shape / reduce_strides) and over the
// non-reduced axes (shape / strides).
struct ColReduceArgs {
  // The size of the contiguous column reduction.
  size_t reduction_size;
  int64_t reduction_stride;
  // Input shape and strides excluding the reduction axes.
  Shape shape;
  Strides strides;
  int ndim;
  // Input shape and strides of the reduction axes (including last dimension).
  Shape reduce_shape;
  Strides reduce_strides;
  int reduce_ndim;
  // The number of reductions outside the last reduction axis, namely the
  // product of reduce_shape[0 : reduce_ndim - 1] (the last axis is excluded).
  size_t non_col_reductions;

  // Build the args from the reduction plan. The non-reduced axes come from
  // shapes_without_reduction_axes(). Trailing ones are dropped while their
  // accumulated size is below reduction_stride (presumably because the
  // kernels cover them by walking reduction_stride consecutive elements —
  // NOTE(review): confirm). The rest are ordered by decreasing stride and
  // then passed through collapse_contiguous_dims().
  ColReduceArgs(
      const array& in,
      const ReductionPlan& plan,
      const std::vector& axes) {
    using ShapeVector = decltype(plan.shape);
    using StridesVector = decltype(plan.strides);
    ShapeVector shape_vec;
    StridesVector strides_vec;
    assert(!plan.shape.empty());
    reduction_size = plan.shape.back();
    reduction_stride = plan.strides.back();
    int64_t stride_back = 1;
    std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);
    // Drop innermost non-reduced axes until their product reaches the stride.
    while (!shape_vec.empty() && stride_back < reduction_stride) {
      stride_back *= shape_vec.back();
      shape_vec.pop_back();
      strides_vec.pop_back();
    }
    // Order the remaining axes from largest to smallest stride.
    std::vector indices(shape_vec.size());
    std::iota(indices.begin(), indices.end(), 0);
    std::sort(indices.begin(), indices.end(), [&](int left, int right) {
      return strides_vec[left] > strides_vec[right];
    });
    ShapeVector sorted_shape;
    StridesVector sorted_strides;
    for (auto idx : indices) {
      sorted_shape.push_back(shape_vec[idx]);
      sorted_strides.push_back(strides_vec[idx]);
    }
    std::tie(shape_vec, strides_vec) =
        collapse_contiguous_dims(sorted_shape, sorted_strides);
    shape = const_param(shape_vec);
    strides = const_param(strides_vec);
    ndim = shape_vec.size();
    reduce_shape = const_param(plan.shape);
    reduce_strides = const_param(plan.strides);
    reduce_ndim = plan.shape.size();
    non_col_reductions = 1;
    for (int i = 0; i < reduce_ndim - 1; i++) {
      non_col_reductions *= reduce_shape[i];
    }
  }
};

// General strided column reduction.
//
// Each threadblock produces BN consecutive outputs (tile_x) of one output row
// (tile_y). When BLOCKS > 1 the reduction range is split into BLOCKS slices;
// the block reduces slice tile_out and writes into slice tile_out of `out`,
// which must then be reduced again.
//
// The host launches BM * BN / N_READS threads, i.e. BM rows of
// BN / N_READS threads, each accumulating N_READS adjacent columns. Rows
// advance through the reduction BM positions at a time.
template <
    typename T,
    typename U,
    typename Op,
    int NDIM,
    int BM,
    int BN,
    int N_READS = 4,
    int BLOCKS = 1>
__global__ void col_reduce_looped(
    T* in,
    U* out,
    const __grid_constant__ ColReduceArgs args,
    int64_t out_size) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition(block);
  constexpr int threads_per_row = BN / N_READS;
  // Compute the indices for the tile
  size_t tile_idx = grid.block_rank();
  size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
  size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
  size_t tile_out = tile_y / out_size;
  tile_y = tile_y % out_size;
  // Compute the indices for the thread within the tile
  short thread_x = block.thread_rank() % threads_per_row;
  short thread_y = block.thread_rank() / threads_per_row;
  // Move the input pointer
  in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
      tile_x * BN;
  // Initialize the running totals
  Op op;
  U totals[N_READS];
  for (int i = 0; i < N_READS; i++) {
    totals[i] = ReduceInit::value();
  }
  // [start, end) is the range of reduction positions this block covers;
  // thread row thread_y starts at its own offset and strides by BM.
  size_t total = args.non_col_reductions * args.reduction_size;
  size_t per_block, start, end;
  if constexpr (BLOCKS > 1) {
    per_block = (total + BLOCKS - 1) / BLOCKS;
    start = tile_out * per_block + thread_y;
    end = min((tile_out + 1) * per_block, total);
  } else {
    per_block = total;
    start = thread_y;
    end = total;
  }
  LoopedElemToLoc 2)> loop(args.reduce_ndim);
  loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
  if (tile_x * BN + BN <= args.reduction_stride) {
    // Full tile: every column of the tile is in range.
    if (args.reduction_stride % N_READS == 0) {
      // Rows are N_READS-aligned, so vectorized loads can be used.
      for (size_t r = start; r < end; r += BM) {
        T vals[N_READS];
        cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
        for (int i = 0; i < N_READS; i++) {
          totals[i] = op(totals[i], cast_to(vals[i]));
        }
        loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
      }
    } else {
      for (size_t r = start; r < end; r += BM) {
        T vals[N_READS];
        cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
        for (int i = 0; i < N_READS; i++) {
          totals[i] = op(totals[i], cast_to(vals[i]));
        }
        loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
      }
    }
  } else {
    // Partial last tile: columns past reduction_stride load the identity.
    for (size_t r = start; r < end; r += BM) {
      T vals[N_READS];
      cub::LoadDirectBlocked(
          thread_x,
          in + loop.location(),
          vals,
          args.reduction_stride - tile_x * BN,
          cast_to(ReduceInit::value()));
      for (int i = 0; i < N_READS; i++) {
        totals[i] = op(totals[i], cast_to(vals[i]));
      }
      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
    }
  }
  // Do warp reduce for each output.
  // The BM x BN partial totals are staged in shared memory, then warp w
  // reduces columns [w * n_outputs, (w + 1) * n_outputs) across the BM rows,
  // with lane l reading row l.
  constexpr int n_outputs = BN / threads_per_row;
  static_assert(BM == 32 && n_outputs == N_READS);
  __shared__ U shared_vals[BM * BN];
  short s_idx = thread_y * BN + thread_x * N_READS;
  for (int i = 0; i < N_READS; i++) {
    shared_vals[s_idx + i] = totals[i];
  }
  block.sync();
  s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
  for (int i = 0; i < n_outputs; i++) {
    totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op);
  }
  // Write result.
  // Lane 0 of each warp stores its n_outputs values, clipped to the columns
  // that exist in this tile.
  if (warp.thread_rank() == 0) {
    if (BLOCKS > 1) {
      out += tile_out * out_size * args.reduction_stride;
    }
    cub::StoreDirectBlocked(
        warp.meta_group_rank(),
        out + tile_y * args.reduction_stride + tile_x * BN,
        totals,
        args.reduction_stride - tile_x * BN);
  }
}

// Column reduction for short columns: each thread owns N_READS consecutive
// outputs and walks the whole reduction_size-long column by itself, with no
// inter-thread communication. Threads whose first output is >= total exit.
// NOTE(review): the offset arithmetic ignores args.shape / args.strides, i.e.
// it assumes the input is laid out densely as
// [before_axis][reduction_size][reduction_stride] — confirm callers
// guarantee this.
template
__global__ void col_reduce_small(
    const T* in,
    U* out,
    const __grid_constant__ ColReduceArgs args,
    size_t total) {
  Op op;
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  const auto idx = grid.thread_rank() * N_READS;
  const auto before_axis = idx / args.reduction_stride;
  const auto after_axis = idx % args.reduction_stride;
  const auto offset =
      before_axis * args.reduction_stride * args.reduction_size + after_axis;
  if (idx >= total) {
    return;
  }
  in += offset;
  out += idx;
  AlignedVector accumulator;
  for (int i = 0; i < N_READS; i++) {
    accumulator[i] = ReduceInit::value();
  }
  for (int i = 0; i < args.reduction_size; i++) {
    auto values = load_vector(in, 0);
    for (int j = 0; j < N_READS; j++) {
      accumulator[j] = op(accumulator[j], cast_to(values[j]));
    }
    in += args.reduction_stride;
  }
  store_vector(out, 0, accumulator);
}
} // namespace cu

// Grid for col_reduce_looped. There is one block per bn-wide tile of every
// output row, times `outer` reduction slices. The count is split into gx * gy,
// doubling gy until n_blocks / gy no longer exceeds INT32_MAX.
inline auto output_grid_for_col_reduce(
    const array& out,
    const cu::ColReduceArgs& args,
    int bn,
    int outer = 1) {
  int gx, gy = 1;
  size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
  size_t n_outer_blocks = out.size() / args.reduction_stride;
  size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
  while (n_blocks / gy > INT32_MAX) {
    gy *= 2;
  }
  gx = cuda::ceil_div(n_blocks, gy);
  return dim3(gx, gy, 1);
}

// Single-pass launch of the general looped column-reduce kernel.
void col_reduce_looped(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector& axes,
    const ReductionPlan& plan,
    const cu::ColReduceArgs& args) {
  // Allocate data for the output using in's layout to access them as
  // contiguously as possible.
  allocate_same_layout(out, in, axes, encoder);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_all_types(in.dtype(), [&](auto type_tag) {
    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
      dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
        using OP = MLX_GET_TYPE(reduce_type_tag);
        using T = cuda_type_t;
        using U = typename cu::ReduceResult::type;
        // Cub doesn't like const pointers for vectorized loads. (sigh)
        T* indata = const_cast(gpu_ptr(in));
        constexpr int N_READS = 4;
        constexpr int BM = 32;
        constexpr int BN = 32;
        dim3 grid = output_grid_for_col_reduce(out, args, BN);
        int blocks = BM * BN / N_READS;
        auto kernel = cu::col_reduce_looped;
        encoder.add_kernel_node(
            kernel,
            grid,
            blocks,
            indata,
            gpu_ptr(out),
            static_cast(args),
            out.size() / args.reduction_stride);
      });
    });
  });
}

// Launch col_reduce_small with N_READS = 16 / sizeof(T) outputs per thread.
// NOTE(review): the grid is sized from get_2d_grid_dims(out.shape(), ...)
// rather than out.size() / N_READS, so the surplus threads exit early.
void col_reduce_small(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector& axes,
    const ReductionPlan& plan,
    const cu::ColReduceArgs& args) {
  // Allocate data for the output using in's layout to access them as
  // contiguously as possible.
  allocate_same_layout(out, in, axes, encoder);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_all_types(in.dtype(), [&](auto type_tag) {
    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
      using OP = MLX_GET_TYPE(reduce_type_tag);
      using T = cuda_type_t;
      using U = typename cu::ReduceResult::type;
      constexpr int N_READS = 16 / sizeof(T);
      auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
      auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
      auto kernel = cu::col_reduce_small;
      encoder.add_kernel_node(
          kernel,
          grid,
          block,
          gpu_ptr(in),
          gpu_ptr(out),
          static_cast(args),
          out.size());
    });
  });
}

// Two-pass column reduction for long columns.
//
// Pass 1 runs the looped kernel with BLOCKS = outer (32). Each slice of the
// reduction writes its partial outputs into a (outer, *out.shape) intermediate
// whose inner layout matches `out`. Pass 2 reduces the `outer` partials,
// spaced out.size() apart, into `out`.
void col_reduce_two_pass(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector& axes,
    const ReductionPlan& plan,
    const cu::ColReduceArgs& args) {
  // Allocate data for the output using in's layout to access them as
  // contiguously as possible.
  allocate_same_layout(out, in, axes, encoder);
  // Allocate an intermediate array to hold the 1st pass result
  constexpr int outer = 32;
  Shape intermediate_shape;
  intermediate_shape.push_back(outer);
  intermediate_shape.insert(
      intermediate_shape.end(), out.shape().begin(), out.shape().end());
  Strides intermediate_strides;
  intermediate_strides.push_back(out.size());
  intermediate_strides.insert(
      intermediate_strides.end(), out.strides().begin(), out.strides().end());
  array intermediate(intermediate_shape, out.dtype(), nullptr, {});
  auto [data_size, rc, cc] =
      check_contiguity(intermediate_shape, intermediate_strides);
  auto fl = out.flags();
  fl.row_contiguous = rc;
  fl.col_contiguous = cc;
  fl.contiguous = true;
  intermediate.set_data(
      cu::malloc_async(intermediate.nbytes(), encoder),
      data_size,
      intermediate_strides,
      fl,
      allocator::free);
  encoder.add_temporary(intermediate);
  encoder.set_input_array(in);
  encoder.set_output_array(intermediate);
  dispatch_all_types(in.dtype(), [&](auto type_tag) {
    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
      dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
        using OP = MLX_GET_TYPE(reduce_type_tag);
        using T = cuda_type_t;
        using U = typename cu::ReduceResult::type;
        // Cub doesn't like const pointers for vectorized loads. (sigh)
        T* indata = const_cast(gpu_ptr(in));
        constexpr int N_READS = 4;
        constexpr int BM = 32;
        constexpr int BN = 32;
        dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
        int blocks = BM * BN / N_READS;
        auto kernel = cu::
            col_reduce_looped;
        encoder.add_kernel_node(
            kernel,
            grid,
            blocks,
            indata,
            gpu_ptr(intermediate),
            static_cast(args),
            out.size() / args.reduction_stride);
      });
    });
  });
  // Prepare the reduction arguments for the 2nd pass
  // The 2nd pass is a single column reduction over the `outer` partials, with
  // the out.size() outputs treated as one contiguous row.
  cu::ColReduceArgs second_args = args;
  second_args.reduction_size = outer;
  second_args.reduction_stride = out.size();
  second_args.ndim = 0;
  second_args.reduce_shape[0] = outer;
  second_args.reduce_strides[0] = out.size();
  second_args.reduce_ndim = 1;
  second_args.non_col_reductions = 1;
  encoder.set_input_array(intermediate);
  encoder.set_output_array(out);
  dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
      dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
        using OP = MLX_GET_TYPE(reduce_type_tag);
        using T = cuda_type_t;
        using U = typename cu::ReduceResult::type;
        constexpr int N_READS = 4;
        constexpr int BM = 32;
        constexpr int BN = 32;
        dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
        int blocks = BM * BN / N_READS;
        auto kernel = cu::col_reduce_looped;
        encoder.add_kernel_node(
            kernel,
            grid,
            blocks,
            gpu_ptr(intermediate),
            gpu_ptr(out),
            second_args,
            second_args.reduction_stride);
      });
    });
  });
}

// Entry point for column reductions: choose a kernel based on the shape of
// the reduction.
void col_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector& axes,
    const ReductionPlan& plan) {
  // Current col reduce options
  //
  // - col_reduce_looped
  //
  //   It is a general strided reduce. Each threadblock computes the output for
  //   a subrow of the fast moving axis. For instance 32 elements.
  //
  // - col_reduce_small
  //
  //   It is a column reduce for small columns. Each thread loops over the whole
  //   column without communicating with any other thread.
  //
  // - col_reduce_two_pass
  //
  //   It is a reduce for long columns. To increase parallelism, we split the
  //   reduction in two passes. First we do a column reduce where many
  //   threadblocks operate on different parts of the reduced axis. Then we
  //   perform a final column reduce.
  //
  // Notes: As in row reduce we opt to read as much in order as possible and
  //        leave transpositions as they are (contrary to our Metal backend).
  //
  //        Moreover we still need dedicated kernels for short rows, and the
  //        routing thresholds below need further tuning.
  // Make the args struct to help route to the best kernel
  cu::ColReduceArgs args(in, plan, axes);
  // Small col reduce with a single or contiguous reduction axis
  // (the stride must be a multiple of the 16-byte vector width in elements).
  if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
      args.reduction_stride % (16 / in.itemsize()) == 0) {
    col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
    return;
  }
  // Long column with smallish row
  // (more than 32 reduced elements per output element).
  size_t total_sums = args.non_col_reductions * args.reduction_size;
  size_t approx_threads = out.size();
  if (total_sums / approx_threads > 32) {
    col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
    return;
  }
  // Fallback col reduce
  col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}
} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/reduce/init_reduce.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Fill out[0, size) with the identity element of reduction Op for input
// type T.
template <typename T, typename U, typename Op>
__global__ void init_reduce(U* out, size_t size) {
  auto index = cg::this_grid().thread_rank();
  if (index < size) {
    out[index] = ReduceInit<Op, T>::value();
  }
}

} // namespace cu

// Initialize `out` with the reduction identity (used when there is nothing
// to reduce). `out` is allocated here only if it has no data yet.
void init_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type) {
  // Allocate if needed
  if (out.data_shared_ptr() == nullptr) {
    out.set_data(cu::malloc_async(out.nbytes(), encoder));
  }

  // Nothing to fill. Launching anyway would derive a zero-sized block from
  // the empty shape, which is an invalid launch configuration.
  if (out.size() == 0) {
    return;
  }

  encoder.set_output_array(out);
  dispatch_all_types(in.dtype(), [&](auto type_tag) {
    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
      using OP = MLX_GET_TYPE(reduce_type_tag);
      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
      using U = typename cu::ReduceResult<OP, T>::type;
      auto kernel = cu::init_reduce<T, U, OP>;
      dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
      dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
      grid.x = (grid.x + 1023) / 1024;
      encoder.add_kernel_node(kernel, grid, block, gpu_ptr<U>(out), out.size());
    });
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/reduce/reduce.cuh
================================================
// Copyright © 2025 Apple Inc.
#pragma once

#include <type_traits>

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

// Call f with a compile-time reduction rank: 1, 2, or a general rank used for
// everything larger.
template <typename F>
void dispatch_reduce_ndim(int ndim, F&& f) {
  if (ndim == 1) {
    f(std::integral_constant<int, 1>{});
  } else if (ndim == 2) {
    f(std::integral_constant<int, 2>{});
  } else {
    f(std::integral_constant<int, 5>{});
  }
}

// Call f with a type tag for the cu:: reduce op that matches reduce_type.
template <typename F>
void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) {
  if (reduce_type == Reduce::ReduceType::And) {
    f(type_identity<cu::And>{});
  } else if (reduce_type == Reduce::ReduceType::Or) {
    f(type_identity<cu::Or>{});
  } else if (reduce_type == Reduce::ReduceType::Sum) {
    f(type_identity<cu::Sum>{});
  } else if (reduce_type == Reduce::ReduceType::Prod) {
    f(type_identity<cu::Prod>{});
  } else if (reduce_type == Reduce::ReduceType::Max) {
    f(type_identity<cu::Max>{});
  } else if (reduce_type == Reduce::ReduceType::Min) {
    f(type_identity<cu::Min>{});
  } else {
    throw std::invalid_argument("Unknown reduce type.");
  }
}

void all_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type);

void row_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan);

void col_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan);

void init_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type);

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/reduce/reduce_ops.cuh
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device/atomic_ops.cuh"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"

namespace mlx::core::cu {

// Reduce ops.
// Each op provides the binary combine operator() and an atomic_update used to
// fold a value into global memory.

struct And {
  __device__ __forceinline__ bool operator()(bool a, bool b) {
    return a && b;
  }

  __device__ void atomic_update(bool* x, bool y) {
    atomic_reduce<bool, And>(x, y);
  }
};

struct Or {
  __device__ __forceinline__ bool operator()(bool a, bool b) {
    return a || b;
  }

  __device__ void atomic_update(bool* x, bool y) {
    atomic_reduce<bool, Or>(x, y);
  }
};

struct Sum {
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) {
    return a + b;
  }

  // Generic fallback through a compare-and-swap loop; the overloads below
  // use atomic_add for the types it handles.
  template <typename T>
  __device__ void atomic_update(T* x, T y) {
    atomic_reduce<T, Sum>(x, y);
  }

  __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
    atomic_add(x, y);
  }

  __device__ void atomic_update(int* x, int y) {
    atomic_add(x, y);
  }

  __device__ void atomic_update(float* x, float y) {
    atomic_add(x, y);
  }
};

struct Prod {
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) {
    return a * b;
  }

  template <typename T>
  __device__ void atomic_update(T* x, T y) {
    atomic_reduce<T, Prod>(x, y);
  }
};

// Min and Max propagate NaN: for complex values the NaN-carrying operand is
// returned; for floating values a quiet NaN is returned.
struct Min {
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) {
    if constexpr (is_complex_v<T>) {
      if (cuda::std::isnan(a.real()) || cuda::std::isnan(a.imag())) {
        return a;
      }
      if (cuda::std::isnan(b.real()) || cuda::std::isnan(b.imag())) {
        return b;
      }
    } else if constexpr (!cuda::std::is_integral_v<T>) {
      if (cuda::std::isnan(a) || cuda::std::isnan(b)) {
        return cuda::std::numeric_limits<T>::quiet_NaN();
      }
    }
    return a < b ? a : b;
  }

  template <typename T>
  __device__ void atomic_update(T* x, T y) {
    atomic_reduce<T, Min>(x, y);
  }
};

struct Max {
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) {
    if constexpr (is_complex_v<T>) {
      if (cuda::std::isnan(a.real()) || cuda::std::isnan(a.imag())) {
        return a;
      }
      if (cuda::std::isnan(b.real()) || cuda::std::isnan(b.imag())) {
        return b;
      }
    } else if constexpr (!cuda::std::is_integral_v<T>) {
      if (cuda::std::isnan(a) || cuda::std::isnan(b)) {
        return cuda::std::numeric_limits<T>::quiet_NaN();
      }
    }
    return a > b ? a : b;
  }

  template <typename T>
  __device__ void atomic_update(T* x, T y) {
    atomic_reduce<T, Max>(x, y);
  }
};

// Traits to get the result type of reduce op.
template <typename Op, typename T>
struct ReduceResult;

template <typename T>
struct ReduceResult<And, T> {
  using type = bool;
};

template <typename T>
struct ReduceResult<Or, T> {
  using type = bool;
};

// Sum and Prod accumulate integers of at most 4 bytes in int32_t.
template <typename T>
struct ReduceResult<Sum, T> {
  using type = cuda::std::conditional_t<
      (cuda::std::is_integral_v<T> && sizeof(T) <= 4),
      int32_t,
      T>;
};

template <typename T>
struct ReduceResult<Prod, T> {
  using type = cuda::std::conditional_t<
      (cuda::std::is_integral_v<T> && sizeof(T) <= 4),
      int32_t,
      T>;
};

template <typename T>
struct ReduceResult<Min, T> {
  using type = T;
};

template <typename T>
struct ReduceResult<Max, T> {
  using type = T;
};

// Traits to get the init value of reduce op.
template <typename Op, typename T>
struct ReduceInit;

template <typename T>
struct ReduceInit<And, T> {
  static constexpr __host__ __device__ bool value() {
    return true;
  }
};

template <typename T>
struct ReduceInit<Or, T> {
  static constexpr __host__ __device__ bool value() {
    return false;
  }
};

template <typename T>
struct ReduceInit<Sum, T> {
  static constexpr __host__ __device__ auto value() {
    if constexpr (is_complex_v<T>) {
      return T{0, 0};
    } else {
      return cast_to<typename ReduceResult<Sum, T>::type>(0);
    }
  }
};

template <typename T>
struct ReduceInit<Prod, T> {
  static constexpr __host__ __device__ auto value() {
    if constexpr (is_complex_v<T>) {
      return T{1, 0};
    } else {
      return cast_to<typename ReduceResult<Prod, T>::type>(1);
    }
  }
};

template <typename T>
struct ReduceInit<Min, T> {
  static constexpr __host__ __device__ T value() {
    return Limits<T>::max();
  }
};

template <typename T>
struct ReduceInit<Max, T> {
  static constexpr __host__ __device__ T value() {
    return Limits<T>::min();
  }
};

} // namespace mlx::core::cu

================================================
FILE: mlx/backend/cuda/reduce/reduce_utils.cuh
================================================
// Copyright © 2025 Apple Inc.
#pragma once #include #include "mlx/backend/common/utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/utils.cuh" #include #include namespace mlx::core { namespace cu { namespace cg = cooperative_groups; template struct uint_by_size; template <> struct uint_by_size<2> { using type = uint16_t; }; template <> struct uint_by_size<4> { using type = uint32_t; }; template <> struct uint_by_size<8> { using type = unsigned long long int; }; template __device__ void atomic_reduce(T* x, T y) { if constexpr (sizeof(T) == 1) { using U = uint16_t; U* x_int = (U*)((char*)x - ((size_t)x % 2)); int shift = ((char*)x - (char*)x_int) * 8; int mask = 0xff << shift; U old_val, new_val; do { old_val = *x_int; T result = Op{}(static_cast((old_val >> shift) & 0xff), y); new_val = (old_val & ~mask) | (result << shift); } while (atomicCAS(x_int, old_val, new_val) != old_val); } else { using U = typename uint_by_size::type; U* x_int = (U*)(x); U old_val, new_val; do { old_val = *x_int; T result = Op{}(*((T*)&old_val), y); new_val = *((U*)&result); } while (atomicCAS(x_int, old_val, new_val) != old_val); } } template inline __device__ void block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { // First reduce in the current warp for (int i = 0; i < N; i++) { vals[i] = cg::reduce(warp, vals[i], op); } // Reduce across warps if (warp.meta_group_size() > 1) { if (warp.thread_rank() == 0) { for (int i = 0; i < N; i++) { smem[warp.meta_group_rank() * N + i] = vals[i]; } } block.sync(); if (warp.thread_rank() < warp.meta_group_size()) { for (int i = 0; i < N; i++) { vals[i] = smem[warp.thread_rank() * N + i]; } } else { for (int i = 0; i < N; i++) { vals[i] = init; } } for (int i = 0; i < N; i++) { vals[i] = cg::reduce(warp, vals[i], op); } } } } // namespace cu inline void allocate_same_layout( array& out, const array& in, const std::vector& axes, cu::CommandEncoder& encoder) { if (in.flags().row_contiguous) { 
out.set_data(cu::malloc_async(out.nbytes(), encoder)); return; } if (out.ndim() < in.ndim()) { throw std::runtime_error( "Reduction without keepdims only supported for row-contiguous inputs"); } // Calculate the transpositions applied to in in order to apply them to out. std::vector axis_order(in.ndim()); std::iota(axis_order.begin(), axis_order.end(), 0); std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) { return in.strides(left) > in.strides(right); }); // Transpose the shape and calculate the strides Shape out_shape(in.ndim()); Strides out_strides(in.ndim(), 1); for (int i = 0; i < in.ndim(); i++) { out_shape[i] = out.shape(axis_order[i]); } for (int i = in.ndim() - 2; i >= 0; i--) { out_strides[i] = out_shape[i + 1] * out_strides[i + 1]; } // Reverse the axis order to get the final strides Strides final_strides(in.ndim()); for (int i = 0; i < in.ndim(); i++) { final_strides[axis_order[i]] = out_strides[i]; } // Calculate the resulting contiguity and do the memory allocation auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides); auto fl = in.flags(); fl.row_contiguous = rc; fl.col_contiguous = cc; fl.contiguous = true; out.set_data( cu::malloc_async(out.nbytes(), encoder), data_size, final_strides, fl, allocator::free); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/reduce/row_reduce.cu ================================================ // Copyright © 2025 Apple Inc. #include #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/reduce/reduce.cuh" #include #include namespace mlx::core { namespace cu { namespace cg = cooperative_groups; struct RowReduceArgs { // The size of the row being reduced, i.e. the size of last dimension. int row_size; // Input shape and strides excluding the reduction axes. Shape shape; Strides strides; int ndim; // Input shape and strides of the reduction axes excluding last dimension. 
Shape reduce_shape; Strides reduce_strides; int reduce_ndim; // The number of rows we are reducing. Namely prod(reduce_shape). size_t non_row_reductions; RowReduceArgs( const array& in, const ReductionPlan& plan, const std::vector& axes) { assert(!plan.shape.empty()); row_size = plan.shape.back(); auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); std::tie(shape_vec, strides_vec) = collapse_contiguous_dims(shape_vec, strides_vec); shape = const_param(shape_vec); strides = const_param(strides_vec); ndim = shape_vec.size(); reduce_shape = const_param(plan.shape); reduce_strides = const_param(plan.strides); reduce_ndim = plan.shape.size() - 1; non_row_reductions = 1; for (int i = 0; i < reduce_ndim; i++) { non_row_reductions *= reduce_shape[i]; } } // Convert shape and strides as if in was contiguous void sort_access_pattern(const array& in, const std::vector& axes) { auto shape_vec = in.shape(); auto strides_vec = in.strides(); std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(shape_vec, strides_vec, axes); std::vector indices(shape_vec.size()); std::iota(indices.begin(), indices.end(), 0); std::sort(indices.begin(), indices.end(), [&](int left, int right) { return strides_vec[left] > strides_vec[right]; }); decltype(shape_vec) sorted_shape; decltype(strides_vec) sorted_strides; for (auto idx : indices) { sorted_shape.push_back(shape_vec[idx]); sorted_strides.push_back(strides_vec[idx]); } std::tie(shape_vec, strides_vec) = collapse_contiguous_dims(sorted_shape, sorted_strides); shape = const_param(shape_vec); strides = const_param(strides_vec); ndim = shape_vec.size(); } }; template __global__ void row_reduce_simple(const T* in, U* out, size_t n_rows, int size) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); const U init = cu::ReduceInit::value(); ReduceOp op; AlignedVector vals[M]; AlignedVector accs; for (int i = 0; i < M; i++) { accs[i] = init; } const size_t 
start_row = min(n_rows - M, static_cast(grid.block_rank() * M)); const size_t full_blocks = size / (block.size() * N); const size_t final_offset = full_blocks * (block.size() * N); in += start_row * size + block.thread_rank() * N; out += start_row; for (size_t r = 0; r < full_blocks; r++) { for (int k = 0; k < M; k++) { vals[k] = load_vector(in + k * size, 0); } for (int k = 0; k < M; k++) { for (int j = 0; j < N; j++) { accs[k] = op(accs[k], cast_to(vals[k][j])); } } in += block.size() * N; } if (final_offset < size) { for (int k = 0; k < M; k++) { for (int i = 0; i < N; i++) { vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size) ? in[k * size + i] : cast_to(init); } } for (int k = 0; k < M; k++) { for (int j = 0; j < N; j++) { accs[k] = op(accs[k], cast_to(vals[k][j])); } } } __shared__ U shared_accumulators[32 * M]; block_reduce(block, warp, accs.val, shared_accumulators, op, init); if (block.thread_rank() == 0) { if (grid.block_rank() * M + M <= n_rows) { store_vector(out, 0, accs); } else { short offset = grid.block_rank() * M + M - n_rows; for (int i = offset; i < M; i++) { out[i] = accs[i]; } } } } template __global__ void row_reduce_looped( const T* in, U* out, const __grid_constant__ RowReduceArgs args) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); size_t out_idx = grid.block_rank(); Op op; U total[1]; U init = ReduceInit::value(); total[0] = init; LoopedElemToLoc 2)> loop(args.reduce_ndim); const size_t full_blocks = args.row_size / (block.size() * N_READS); const size_t final_offset = full_blocks * (block.size() * N_READS); in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); in += block.thread_rank() * N_READS; // Unaligned reduce if (final_offset < args.row_size) { bool mask[N_READS]; for (int i = 0; i < N_READS; i++) { mask[i] = (final_offset + block.thread_rank() * N_READS + i) < args.row_size; } for (size_t n = 0; n < args.non_row_reductions; n++) 
{ const T* inlocal = in + loop.location(); for (size_t r = 0; r < full_blocks; r++) { auto vals = load_vector(inlocal, 0); for (int i = 0; i < N_READS; i++) { total[0] = op(total[0], cast_to(vals[i])); } inlocal += block.size() * N_READS; } { T vals[N_READS]; for (int i = 0; i < N_READS; i++) { vals[i] = mask[i] ? inlocal[i] : cast_to(init); } for (int i = 0; i < N_READS; i++) { total[0] = op(total[0], cast_to(vals[i])); } } loop.next(args.reduce_shape.data(), args.reduce_strides.data()); } } // Aligned case else { for (size_t n = 0; n < args.non_row_reductions; n++) { const T* inlocal = in + loop.location(); for (size_t r = 0; r < full_blocks; r++) { auto vals = load_vector(inlocal, 0); for (int i = 0; i < N_READS; i++) { total[0] = op(total[0], cast_to(vals[i])); } inlocal += block.size() * N_READS; } loop.next(args.reduce_shape.data(), args.reduce_strides.data()); } } __shared__ U shared_accumulators[32]; block_reduce(block, warp, total, shared_accumulators, op, init); if (block.thread_rank() == 0) { out[out_idx] = total[0]; } } } // namespace cu void row_reduce_simple( cu::CommandEncoder& encoder, const array& in, array& out, Reduce::ReduceType reduce_type, const std::vector& axes, const ReductionPlan& plan) { // Allocate data for the output using in's layout to avoid elem_to_loc in the // kernel. allocate_same_layout(out, in, axes, encoder); // TODO: If out.size() < 1024 which will be a common case then write this in // 2 passes. Something like 32 * out.size() and then do a warp reduce. 
encoder.set_input_array(in); encoder.set_output_array(out); dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t; using U = typename cu::ReduceResult::type; constexpr int N_READS = 16 / sizeof(T); // Calculate the grid and block dims size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE; warps /= 4; warps = std::max(std::min(warps, 32), 1); int threads = warps * WARP_SIZE; dim3 block(threads, 1, 1); // Pick the kernel auto kernel = cu::row_reduce_simple; if (grid.x >= 1024) { grid.x = (grid.x + 1) / 2; kernel = cu::row_reduce_simple; } T* indata = const_cast(gpu_ptr(in)); int size = plan.shape.back(); encoder.add_kernel_node( kernel, grid, block, indata, gpu_ptr(out), out.size(), size); }); }); } void row_reduce_looped( cu::CommandEncoder& encoder, const array& in, array& out, Reduce::ReduceType reduce_type, const std::vector& axes, const ReductionPlan& plan, cu::RowReduceArgs args) { // Allocate data for the output using in's layout to access them as // contiguously as possible. 
allocate_same_layout(out, in, axes, encoder); encoder.set_input_array(in); encoder.set_output_array(out); dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t; using U = typename cu::ReduceResult::type; constexpr int N_READS = 16 / sizeof(T); // Calculate the grid and block dims args.sort_access_pattern(in, axes); dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); size_t reductions = (args.row_size + N_READS - 1) / N_READS; int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE; warps /= 4; warps = std::max(std::min(warps, 32), 1); int threads = warps * WARP_SIZE; dim3 block(threads, 1, 1); // Pick the kernel auto kernel = cu::row_reduce_looped; dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { kernel = cu::row_reduce_looped; }); encoder.add_kernel_node( kernel, grid, block, gpu_ptr(in), gpu_ptr(out), args); }); }); } void row_reduce( cu::CommandEncoder& encoder, const array& in, array& out, Reduce::ReduceType reduce_type, const std::vector& axes, const ReductionPlan& plan) { // Current row reduction options // // - row_reduce_simple // // That means that we are simply reducing across the fastest moving axis. // We are reducing 1 or 2 rows per threadblock depending on the size of // output. // // - row_reduce_looped // // It is a general row reduction. We are computing 1 output per // threadblock. We read the fastest moving axis vectorized and loop over // the rest of the axes. // // Notes: We opt to read as much in order as possible and leave // transpositions as they are (contrary to our Metal backend). // Simple row reduce means that we have 1 axis that we are reducing over and // it has stride 1. 
if (plan.shape.size() == 1) { row_reduce_simple(encoder, in, out, reduce_type, axes, plan); return; } // Make the args struct to help route to the best kernel cu::RowReduceArgs args(in, plan, axes); // Fallback row reduce row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args)); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/reduce.cu ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/gpu/copy.h" #include #include namespace mlx::core { void Reduce::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Reduce::eval_gpu"); assert(inputs.size() == 1); array in = inputs[0]; // Make sure no identity reductions trickle down here. assert(!axes_.empty()); assert(out.size() != in.size()); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); if (in.size() == 0) { init_reduce(encoder, in, out, reduce_type_); return; } // Reduce. ReductionPlan plan = get_reduction_plan(in, axes_); // If it is a general reduce then copy the input to a contiguous array and // recompute the plan. // // TODO: Instead of copying we can use elem-to-loc to deal with broadcasting // like we do in Metal. When it comes to broadcasted reduction axes // some can be ignored eg for min/max. 
bool broadcasted = false; for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) { if (j < axes_.size() && axes_[j] == i) { j++; } else { broadcasted = in.strides(i) == 0; } } if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) { array in_copy = contiguous_copy_gpu(in, s); encoder.add_temporary(in_copy); in = in_copy; plan = get_reduction_plan(in, axes_); } if (plan.type == ContiguousAllReduce) { all_reduce(encoder, in, out, reduce_type_); return; } if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { row_reduce(encoder, in, out, reduce_type_, axes_, plan); return; } if (plan.type == ContiguousStridedReduce || plan.type == GeneralStridedReduce) { col_reduce(encoder, in, out, reduce_type_, axes_, plan); return; } throw std::runtime_error("No plan reached in reduce."); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/rms_norm.cu ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/fast_primitives.h" #include #include #include namespace mlx::core { namespace cu { namespace cg = cooperative_groups; inline __device__ float2 plus_f2(const float2& a, const float2& b) { return {a.x + b.x, a.y + b.y}; } // Similar to cub::BlockReduce, but result is broadcasted to every thread. template struct BlockBroadcastReduce { using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)]; cg::thread_block& block; TempStorage& temp; template __device__ T Reduce(const T& input, const Op& op, const T& init_value) { auto warp = cg::tiled_partition(block); T x = cg::reduce(warp, input, op); if constexpr (BLOCK_DIM > GROUP_DIM) { if (warp.thread_rank() == 0) { temp[warp.meta_group_rank()] = x; } block.sync(); x = warp.thread_rank() < warp.meta_group_size() ? 
temp[warp.thread_rank()] : init_value; return cg::reduce(warp, x, op); } else { return x; } } __device__ T Sum(const T& input) { return Reduce(input, cg::plus{}, T{}); } }; template __global__ void rms_norm_small( const T* x, const T* w, T* out, float eps, uint32_t axis_size, uint32_t n_rows, int64_t w_stride) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); using BlockReduceT = BlockBroadcastReduce; __shared__ typename BlockReduceT::TempStorage temp; auto row = (grid.block_rank() * block.dim_threads().y) + block.thread_index().y; if (row >= n_rows) { return; } x += row * axis_size; out += row * axis_size; // Normalizer. float normalizer = 0; auto index = block.thread_index().x; auto xn = load_vector(x, index, axis_size, T(0)); #pragma unroll for (int i = 0; i < N_READS; ++i) { float t = static_cast(xn[i]); normalizer += t * t; } normalizer = BlockReduceT{block, temp}.Sum(normalizer); normalizer = rsqrt(normalizer / axis_size + eps); // Outputs. auto wn = load_vector(w, index, axis_size, w_stride, T(0)); #pragma unroll for (int i = 0; i < N_READS; ++i) { float y = static_cast(xn[i]) * normalizer; xn[i] = wn[i] * static_cast(y); } store_vector(out, index, xn, axis_size); } template __global__ void rms_norm( const T* x, const T* w, T* out, float eps, uint32_t axis_size, int64_t w_stride) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); using BlockReduceT = BlockBroadcastReduce; __shared__ typename BlockReduceT::TempStorage temp; x += grid.block_rank() * axis_size; out += grid.block_rank() * axis_size; // Normalizer. 
float normalizer = 0; for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); auto xn = load_vector(x, index, axis_size, T(0)); #pragma unroll for (int i = 0; i < N_READS; ++i) { float t = static_cast(xn[i]); normalizer += t * t; } } normalizer = BlockReduceT{block, temp}.Sum(normalizer); normalizer = rsqrt(normalizer / axis_size + eps); // Outputs. for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); auto xn = load_vector(x, index, axis_size, T(0)); auto wn = load_vector(w, index, axis_size, w_stride, T(0)); #pragma unroll for (int i = 0; i < N_READS; ++i) { float y = static_cast(xn[i]) * normalizer; xn[i] = wn[i] * static_cast(y); } store_vector(out, index, xn, axis_size); } } template < typename T, bool HAS_W, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4> __global__ void rms_norm_vjp_small( const T* x, const T* w, const T* g, T* gx, T* gw, float eps, int32_t axis_size, int32_t n_rows, int64_t w_stride) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); using BlockReduceF2 = BlockBroadcastReduce; __shared__ typename BlockReduceF2::TempStorage temp; auto row = (grid.block_rank() * block.dim_threads().y) + block.thread_index().y; if (row >= n_rows) { return; } x += row * axis_size; g += row * axis_size; gx += row * axis_size; gw += row * axis_size; // Normalizer. 
float2 factors = {}; auto index = block.thread_index().x; auto xn = load_vector(x, index, axis_size, T(0)); auto gn = load_vector(g, index, axis_size, T(0)); auto wn = load_vector(w, index, axis_size, w_stride, T(0)); for (int i = 0; i < N_READS; i++) { float t = static_cast(xn[i]); float wi = wn[i]; float gi = gn[i]; float wg = wi * gi; factors = plus_f2(factors, {wg * t, t * t}); } factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {}); float meangwx = factors.x / axis_size; float normalizer = rsqrt(factors.y / axis_size + eps); float normalizer3 = normalizer * normalizer * normalizer; // Outputs. for (int i = 0; i < N_READS; i++) { float xi = xn[i]; float wi = wn[i]; float gi = gn[i]; xn[i] = static_cast(normalizer * wi * gi - xi * meangwx * normalizer3); if constexpr (HAS_W) { wn[i] = static_cast(gi * xi * normalizer); } } store_vector(gx, index, xn, axis_size); if constexpr (HAS_W) { store_vector(gw, index, wn, axis_size); } } template __global__ void rms_norm_vjp( const T* x, const T* w, const T* g, T* gx, T* gw, float eps, int32_t axis_size, int64_t w_stride) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); using BlockReduceF2 = BlockBroadcastReduce; __shared__ typename BlockReduceF2::TempStorage temp; x += grid.block_rank() * axis_size; g += grid.block_rank() * axis_size; gx += grid.block_rank() * axis_size; gw += grid.block_rank() * axis_size; // Normalizer. 
float2 factors = {}; for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); auto xn = load_vector(x, index, axis_size, T(0)); auto gn = load_vector(g, index, axis_size, T(0)); auto wn = load_vector(w, index, axis_size, w_stride, T(0)); for (int i = 0; i < N_READS; i++) { float t = static_cast(xn[i]); float wi = wn[i]; float gi = gn[i]; float wg = wi * gi; factors = plus_f2(factors, {wg * t, t * t}); } } factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {}); float meangwx = factors.x / axis_size; float normalizer = rsqrt(factors.y / axis_size + eps); float normalizer3 = normalizer * normalizer * normalizer; // Outputs. for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); auto xn = load_vector(x, index, axis_size, T(0)); auto gn = load_vector(g, index, axis_size, T(0)); auto wn = load_vector(w, index, axis_size, w_stride, T(0)); for (int i = 0; i < N_READS; i++) { float xi = xn[i]; float wi = wn[i]; float gi = gn[i]; xn[i] = static_cast(normalizer * wi * gi - xi * meangwx * normalizer3); if constexpr (HAS_W) { wn[i] = static_cast(gi * xi * normalizer); } } store_vector(gx, index, xn, axis_size); if constexpr (HAS_W) { store_vector(gw, index, wn, axis_size); } } } } // namespace cu namespace fast { bool RMSNorm::use_fallback(Stream s) { return s.device == Device::cpu; } template void dispatch_group_dim(int axis_size, F&& f) { if (axis_size <= n_per_thread * 8) { f(std::integral_constant{}, std::integral_constant(), std::integral_constant()); } else if (axis_size <= n_per_thread * 16) { f(std::integral_constant{}, std::integral_constant(), std::integral_constant()); } else if (axis_size <= n_per_thread * 32) { f(std::integral_constant{}, std::integral_constant(), std::integral_constant()); } else if (axis_size <= n_per_thread * 32 * 2) { f(std::integral_constant{}, std::integral_constant(), std::integral_constant()); 
} else if (axis_size <= n_per_thread * 32 * 4) { f(std::integral_constant{}, std::integral_constant(), std::integral_constant()); } else if (axis_size <= n_per_thread * 32 * 8) { f(std::integral_constant{}, std::integral_constant(), std::integral_constant()); } else if (axis_size <= n_per_thread * 32 * 16) { f(std::integral_constant{}, std::integral_constant(), std::integral_constant()); } else { f(std::integral_constant{}, std::integral_constant(), std::integral_constant()); } } // TODO: There are duplicate code with backend/metal/normalization.cpp void RMSNorm::eval_gpu( const std::vector& inputs, std::vector& outputs) { nvtx3::scoped_range r("RMSNorm::eval_gpu"); auto& s = stream(); auto& out = outputs[0]; auto& encoder = cu::get_command_encoder(s); // Make sure that the last dimension is contiguous. auto set_output = [&s, &out, &encoder](const array& x) { bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { if (x.is_donatable()) { out.copy_shared_buffer(x); } else { out.set_data( cu::malloc_async(x.data_size() * x.itemsize(), encoder), x.data_size(), x.strides(), x.flags()); } return x; } else { array x_copy = contiguous_copy_gpu(x, s); out.copy_shared_buffer(x_copy); return x_copy; } }; const array x = set_output(inputs[0]); const array& w = inputs[1]; int32_t axis_size = x.shape().back(); int32_t n_rows = x.data_size() / axis_size; int64_t w_stride = (w.ndim() == 1) ? 
w.strides()[0] : 0; encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_output_array(out); dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { using DataType = cuda_type_t; constexpr int N_READS = 16 / sizeof(DataType); if (axis_size <= N_READS * 1024) { dispatch_group_dim( axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) { constexpr int block_dim = n_groups() * group_dim(); auto kernel = cu::rms_norm_small; auto n_blocks = (n_rows + groups_per_block() - 1) / groups_per_block(); encoder.add_kernel_node( kernel, n_blocks, {block_dim, groups_per_block()}, gpu_ptr(x), gpu_ptr(w), gpu_ptr(out), eps_, axis_size, n_rows, w_stride); }); } else { auto kernel = cu::rms_norm; encoder.add_kernel_node( kernel, n_rows, 1024, gpu_ptr(x), gpu_ptr(w), gpu_ptr(out), eps_, axis_size, w_stride); } }); } void RMSNormVJP::eval_gpu( const std::vector& inputs, std::vector& outputs) { nvtx3::scoped_range r("RMSNormVJP::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. auto check_input = [&s](const array& x, bool& copied) { if (x.flags().row_contiguous) { copied = false; return x; } copied = true; return contiguous_copy_gpu(x, s); }; bool donate_x = inputs[0].is_donatable(); bool donate_g = inputs[2].is_donatable(); bool copied; auto x = check_input(inputs[0], copied); donate_x |= copied; const array& w = inputs[1]; bool g_copied; auto g = check_input(inputs[2], g_copied); donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; // Check whether we had a weight. bool has_w = w.ndim() != 0; // Allocate space for the outputs. 
bool g_in_gx = false; if (donate_x) { gx.copy_shared_buffer(x); } else if (donate_g) { gx.copy_shared_buffer(g); g_in_gx = true; } else { gx.set_data(cu::malloc_async(gx.nbytes(), encoder)); } if (g_copied && !g_in_gx) { encoder.add_temporary(g); } int32_t axis_size = x.shape().back(); int32_t n_rows = x.data_size() / axis_size; int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; // Allocate a temporary to store the gradients for w and allocate the output // gradient accumulators. array gw_temp = (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; if (has_w) { if (!g_in_gx && donate_g) { gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder)); encoder.add_temporary(gw_temp); } } encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(g); encoder.set_output_array(gx); encoder.set_output_array(gw_temp); dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { dispatch_bool(has_w, [&](auto has_w_constant) { using DataType = cuda_type_t; constexpr int N_READS = 16 / sizeof(DataType); if (axis_size <= N_READS * 1024) { dispatch_group_dim( axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) { constexpr int block_dim = group_dim() * n_groups(); auto kernel = cu::rms_norm_vjp_small< DataType, has_w_constant.value, block_dim, group_dim(), N_READS>; auto n_blocks = (n_rows + groups_per_block() - 1) / groups_per_block(); encoder.add_kernel_node( kernel, n_blocks, {block_dim, groups_per_block()}, gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), gpu_ptr(gx), gpu_ptr(gw_temp), eps_, axis_size, n_rows, w_stride); }); } else { auto kernel = cu::rms_norm_vjp; encoder.add_kernel_node( kernel, n_rows, 1024, gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), gpu_ptr(gx), gpu_ptr(gw_temp), eps_, axis_size, w_stride); } }); }); if (has_w) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, 
plan); } } } // namespace fast } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/rope.cu ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/fast_primitives.h" #include namespace mlx::core { namespace cu { template __device__ void rope_single_impl( const T* in, T* out, int32_t offset, float inv_freq, float scale, int64_t stride, uint2 pos, uint2 dims) { float L = scale * static_cast(offset); // Compute costheta, sintheta float theta = L * inv_freq; float costheta = cos(theta); float sintheta = sin(theta); // Compute the input and output indices uint32_t index_1, index_2; if (traditional) { index_1 = 2 * pos.x + pos.y * stride; index_2 = index_1 + 1; } else { index_1 = pos.x + pos.y * stride; index_2 = index_1 + dims.x; } // Read and write the output float x1 = static_cast(in[index_1]); float x2 = static_cast(in[index_2]); float rx1; float rx2; if (forward) { rx1 = x1 * costheta - x2 * sintheta; rx2 = x1 * sintheta + x2 * costheta; } else { rx1 = x2 * sintheta + x1 * costheta; rx2 = x2 * costheta - x1 * sintheta; } out[index_1] = static_cast(rx1); out[index_2] = static_cast(rx2); } template __global__ void rope_single( const T* in, T* out, const int32_t* offset, float scale, float base, int64_t stride, uint2 dims) { uint2 pos = make_uint2( blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y); if (pos.x >= dims.x || pos.y >= dims.y) { return; } float d = static_cast(pos.x) / static_cast(dims.x); float inv_freq = exp2(-d * base); rope_single_impl( in, out, *offset, inv_freq, scale, stride, pos, dims); } template __global__ void rope_single_freqs( const T* in, T* out, const int32_t* offset, const float* freqs, float scale, int64_t stride, uint2 dims, int64_t freq_stride) { uint2 pos = make_uint2( blockIdx.x * 
blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y); if (pos.x >= dims.x || pos.y >= dims.y) { return; } float inv_freq = 1.0 / freqs[freq_stride * pos.x]; rope_single_impl( in, out, *offset, inv_freq, scale, stride, pos, dims); } template __device__ void rope_impl( const T* in, T* out, const int* offset, float inv_freq, float scale, const cuda::std::array strides, const cuda::std::array out_strides, int64_t offset_stride, int n_head, uint3 pos, uint3 dims) { auto n_head_up = N * ((n_head + N - 1) / N); auto head_idx = static_cast((pos.z * N) % n_head_up); auto batch_idx = (pos.z * N) / n_head_up; auto batch_offset = offset[batch_idx * offset_stride]; float L = scale * static_cast(pos.y + batch_offset); auto mat_idx = batch_idx * n_head + head_idx; // Compute costheta, sintheta float theta = L * inv_freq; float costheta = cos(theta); float sintheta = sin(theta); // Compute the input and output indices size_t in_index_1, in_index_2; size_t out_index_1, out_index_2; if (traditional) { out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + mat_idx * out_strides[0]; out_index_2 = out_index_1 + 1; in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; in_index_2 = in_index_1 + strides[2]; } else { out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + mat_idx * out_strides[0]; out_index_2 = out_index_1 + dims.x * out_strides[2]; in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; in_index_2 = in_index_1 + dims.x * strides[2]; } for (int i = 0; i < N && head_idx + i < n_head; ++i) { // Read and write the output float x1 = static_cast(in[in_index_1]); float x2 = static_cast(in[in_index_2]); float rx1; float rx2; if (forward) { rx1 = x1 * costheta - x2 * sintheta; rx2 = x1 * sintheta + x2 * costheta; } else { rx1 = x2 * sintheta + x1 * costheta; rx2 = x2 * costheta - x1 * sintheta; } out[out_index_1] = static_cast(rx1); out[out_index_2] = static_cast(rx2); in_index_1 += strides[0]; 
in_index_2 += strides[0]; out_index_1 += out_strides[0]; out_index_2 += out_strides[0]; } } template __global__ void rope( const T* in, T* out, const int32_t* offset, float scale, float base, const __grid_constant__ cuda::std::array strides, const __grid_constant__ cuda::std::array out_strides, int64_t offset_stride, int n_head, uint3 dims) { uint3 pos = make_uint3( blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, blockIdx.z * blockDim.z + threadIdx.z); if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) { return; } float d = static_cast(pos.x) / static_cast(dims.x); float inv_freq = exp2(-d * base); rope_impl( in, out, offset, inv_freq, scale, strides, out_strides, offset_stride, n_head, pos, dims); } template __global__ void rope_freqs( const T* in, T* out, const int32_t* offset, const float* freqs, float scale, float base, const __grid_constant__ cuda::std::array strides, const __grid_constant__ cuda::std::array out_strides, int64_t offset_stride, int n_head, uint3 dims, int64_t freq_stride) { uint3 pos = make_uint3( blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, blockIdx.z * blockDim.z + threadIdx.z); if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) { return; } float inv_freq = 1.0 / freqs[freq_stride * pos.x]; rope_impl( in, out, offset, inv_freq, scale, strides, out_strides, offset_stride, n_head, pos, dims); } } // namespace cu namespace fast { bool RoPE::use_fallback(Stream s) { return s.device == Device::cpu; } void RoPE::eval_gpu( const std::vector& inputs, std::vector& outputs) { nvtx3::scoped_range r("RoPE::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); auto& in = inputs[0]; auto& offset = inputs[1]; auto& out = outputs[0]; cuda::std::array strides; cuda::std::array out_strides; bool donated = false; int ndim = in.ndim(); int B = in.shape(0); int T = in.shape(-2); int D = in.shape(-1); size_t mat_size = T * D; int dispatch_ndim = ndim; while 
(in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { dispatch_ndim--; } int N = 1; for (int i = 1; i < (ndim - 2); ++i) { N *= in.shape(i); } // We apply rope to less that the whole vector so copy to output and then // apply in-place. if (dims_ < D) { donated = true; auto ctype = (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General; copy_gpu(in, out, ctype, s); strides[0] = mat_size; strides[1] = out.strides()[ndim - 2]; strides[2] = out.strides()[ndim - 1]; } // Either copy or apply in-place else if (in.flags().row_contiguous) { if (in.is_donatable()) { donated = true; out.copy_shared_buffer(in); } else { out.set_data(cu::malloc_async(out.nbytes(), encoder)); } strides[0] = mat_size; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; } else if (dispatch_ndim == 3) { // Handle non-contiguous 3D inputs out.set_data(cu::malloc_async(out.nbytes(), encoder)); strides[0] = in.strides()[ndim - 3]; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; } else { // Copy non-contiguous > 3D inputs into the output and treat // input as donated donated = true; copy_gpu(in, out, CopyType::General, s); strides[0] = mat_size; strides[1] = out.strides()[ndim - 2]; strides[2] = out.strides()[ndim - 1]; } out_strides[0] = mat_size; out_strides[1] = out.strides()[ndim - 2]; out_strides[2] = out.strides()[ndim - 1]; // Some flags to help us dispatch below bool single = in.flags().row_contiguous && B == 1 && T == 1; bool with_freqs = inputs.size() == 3; encoder.set_input_array(donated ? 
out : in); encoder.set_input_array(offset); if (with_freqs) { encoder.set_input_array(inputs[2]); } encoder.set_output_array(out); dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { dispatch_bool(traditional_, [&](auto traditional) { dispatch_bool(forward_, [&](auto forward) { using DataType = cuda_type_t; if (single && !with_freqs) { auto kernel = cu::rope_single; uint2 dims = make_uint2(dims_ / 2, N); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); encoder.add_kernel_node( kernel, grid, block, gpu_ptr(donated ? out : in), gpu_ptr(out), gpu_ptr(offset), scale_, std::log2(base_), mat_size, dims); } else if (single) { auto kernel = cu::rope_single_freqs; uint2 dims = make_uint2(dims_ / 2, N); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); encoder.add_kernel_node( kernel, grid, block, gpu_ptr(donated ? out : in), gpu_ptr(out), gpu_ptr(offset), gpu_ptr(inputs[2]), scale_, mat_size, dims, inputs[2].strides(0)); } else if (with_freqs) { auto kernel = cu::rope_freqs; int n_per_thread = 4; uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); uint3 dims = make_uint3(dims_ / 2, T, dimz); auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); int64_t offset_stride = 0; if (inputs[1].ndim() > 0) { offset_stride = inputs[1].strides()[0]; } encoder.add_kernel_node( kernel, grid, block, gpu_ptr(donated ? out : in), gpu_ptr(out), gpu_ptr(offset), gpu_ptr(inputs[2]), scale_, std::log2(base_), strides, out_strides, offset_stride, N, dims, inputs[2].strides(0)); } else { auto kernel = cu::rope; int n_per_thread = 4; uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); uint3 dims = make_uint3(dims_ / 2, T, dimz); auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); int64_t offset_stride = 0; if (inputs[1].ndim() > 0) { offset_stride = inputs[1].strides()[0]; } encoder.add_kernel_node( kernel, grid, block, gpu_ptr(donated ? 
out : in), gpu_ptr(out), gpu_ptr(offset), scale_, std::log2(base_), strides, out_strides, offset_stride, N, dims); } }); }); }); } } // namespace fast } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/scaled_dot_product_attention.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/cudnn_utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/gpu/copy.h" #include "mlx/fast_primitives.h" #include namespace mlx::core { namespace { array prepare_sdpa_input(const array& x, Stream s) { // SDPA kernel's requirements on inputs: // 1. last dim's stride be 1; // 2. pointer be aligned. if (x.strides(-1) != 1 || get_alignment(x) < 16) { array x_copy = contiguous_copy_gpu(x, s); auto& encoder = cu::get_command_encoder(s); encoder.add_temporary(x_copy); return x_copy; } return x; } array prepare_sdpa_sinks(const array& sinks, Stream s) { // cuDNN requires sinks to be float32. if (sinks.dtype() == float32) { return sinks; } array sinks_f32(sinks.shape(), float32, nullptr, {}); copy_gpu(sinks, sinks_f32, CopyType::Vector, s); auto& encoder = cu::get_command_encoder(s); encoder.add_temporary(sinks_f32); return sinks_f32; } void malloc_with_same_layout( cu::CommandEncoder& encoder, array& o, const array& q) { if (q.flags().row_contiguous) { o.set_data(cu::malloc_async(o.nbytes(), encoder)); return; } // fill_order = argsort(q.strides()) Shape fill_order(q.ndim()); std::iota(fill_order.begin(), fill_order.end(), 0); std::stable_sort( fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) { auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1; auto s2 = q.strides(idx2) > 0 ? 
q.strides(idx2) : 1; return s1 < s2; }); // Generate o_strides with fill_order Strides o_strides(q.ndim()); int64_t stride = 1; for (int i : fill_order) { o_strides[i] = stride; stride *= o.shape(i); } // o is a transposed contiguous array o.set_data( cu::malloc_async(o.nbytes(), encoder), o.size(), o_strides, {true, false, false}); } bool use_cudnn_for_decoding( const array& q, const array& k, const array& v, bool has_arr_mask) { if (q.shape(2) != 1) { return false; } if (has_arr_mask) { return false; } // The cuDNN SDPA is faster than vector kernel but for small sequence the // overhead would kill the advantage. constexpr int kv_cache_step = 256; // number is from mlx-lm if (k.shape(2) < kv_cache_step) { return false; } // When called during graph building the strides is not available, and we // rely on |supports_sdpa_vector| to decide whether to use fast sdpa since // we can fallback to |sdpa_vector|. if ((k.status() != array::evaluated) || (v.status() != array::evaluated)) { return false; } // Check if k/v are slices from fixed-size kv cache. auto is_slice = [](const array& kv) { // Get pre-sliced sequence length from strides, and check if the buffer // belongs to a contiguous kv cache. int64_t T_kv = kv.strides(1) / kv.strides(2); if (kv.size() / kv.shape(2) * T_kv != kv.buffer_size() / kv.itemsize()) { return false; } // It is possible to use heuristic to check slices, but for now just make // mlx-lm work. return T_kv % kv_cache_step == 0; }; return is_slice(k) && is_slice(v); } // Get original kv from slices, i.e. 
undo keys[..., :offset, :] array unslice_kv(const array& kv) { Shape shape = kv.shape(); shape[2] = /* T_kv */ kv.strides(1) / kv.strides(2); array copy(shape, kv.dtype(), nullptr, {}); copy.copy_shared_buffer( kv, make_contiguous_strides(shape), {true, true, false}, /* data_size */ kv.buffer_size() / kv.itemsize(), /* offset */ -kv.offset()); return copy; } constexpr int QKV_NDIM = 4; struct SDPACacheKey { int device_id; fe::DataType_t cudnn_dtype; std::array q_shape; std::array k_shape; std::array v_shape; std::array q_strides; std::array k_strides; std::array v_strides; bool do_causal; std::array mask_shape; std::array mask_strides; bool has_sinks; bool output_logsumexp; }; inline BytesKey build_sdpa_cache_key( cu::CommandEncoder& encoder, const array& q, const array& k, const array& v, bool do_causal, const std::optional& mask_arr, const std::optional& sinks, bool decoding = false, bool output_logsumexp = false) { BytesKey cache_key; cache_key.pod.device_id = encoder.device().cuda_device(); cache_key.pod.cudnn_dtype = dtype_to_cudnn_type(q.dtype()); cache_key.pod.q_shape = vector_key(q.shape()); cache_key.pod.k_shape = vector_key(k.shape()); cache_key.pod.v_shape = vector_key(v.shape()); cache_key.pod.q_strides = vector_key(q.strides()); cache_key.pod.k_strides = vector_key(k.strides()); cache_key.pod.v_strides = vector_key(v.strides()); cache_key.pod.do_causal = do_causal; cache_key.pod.has_sinks = sinks.has_value(); cache_key.pod.output_logsumexp = output_logsumexp; if (mask_arr) { cache_key.pod.mask_shape = vector_key(mask_arr->shape()); cache_key.pod.mask_strides = vector_key(mask_arr->strides()); } if (decoding) { int64_t T_kv = k.strides(1) / k.strides(2); cache_key.pod.k_shape[2] = T_kv; cache_key.pod.v_shape[2] = T_kv; cache_key.pod.k_strides.fill(0); cache_key.pod.v_strides.fill(0); } return cache_key; } auto& sdpa_cache() { static LRUBytesKeyCache cache( "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 256); return cache; } auto& 
sdpa_backward_cache() { static LRUBytesKeyCache cache( "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64); return cache; } enum UIDS { Q, K, V, SCALE, BIAS, SINKS, SEQ_LEN_Q, SEQ_LEN_KV, O, STATS, // Backward graph: D_Q, D_K, D_V, D_O, }; DnnGraph build_sdpa_graph( cudnnHandle_t handle, const array& q, const array& k, const array& v, bool do_causal, const std::optional& mask_arr, const std::optional& sinks, const std::optional& seq_len_q, const std::optional& seq_len_kv, bool output_logsumexp, const array& o, const std::optional& stats) { DnnGraph graph(handle, q.dtype()); auto q_ = graph.tensor("Q", Q, q); auto k_ = graph.tensor("K", K, k); auto v_ = graph.tensor("V", V, v); auto options = fe::graph::SDPA_attributes() .set_name("sdpa_cudnn") .set_attn_scale(graph.scalar("Scale", SCALE, float32)) .set_generate_stats(output_logsumexp); if (do_causal) { options.set_causal_mask_bottom_right(do_causal); } if (mask_arr) { options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr)); } if (sinks) { options.set_sink_token(graph.tensor_4d("SINKS", SINKS, *sinks, 1)); } if (seq_len_q && seq_len_kv) { options.set_padding_mask(true); options.set_seq_len_q(graph.tensor("SEQ_LEN_Q", SEQ_LEN_Q, *seq_len_q)); options.set_seq_len_kv(graph.tensor("SEQ_LEN_KV", SEQ_LEN_KV, *seq_len_kv)); } auto [o_, stats_] = graph.sdpa(q_, k_, v_, options); graph.tensor(o_, O, o)->set_output(true); if (output_logsumexp) { graph.tensor(stats_, STATS, *stats)->set_output(true); } CHECK_CUDNN_FE_ERROR(graph.prepare()); graph.select_behavior_notes( {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); CHECK_CUDNN_FE_ERROR(graph.build()); return graph; } DnnGraph build_sdpa_backward_graph( cudnnHandle_t handle, const array& q, const array& k, const array& v, bool do_causal, const std::optional& mask_arr, const std::optional& sinks, const array& o, const array& d_o, const array& stats, array& d_q, array& d_k, array& d_v) { DnnGraph graph(handle, q.dtype()); auto q_ = graph.tensor("Q", Q, q); 
auto k_ = graph.tensor("K", K, k); auto v_ = graph.tensor("V", V, v); auto o_ = graph.tensor("O", O, o); auto d_o_ = graph.tensor("D_O", D_O, d_o); auto stats_ = graph.tensor("STATS", STATS, stats); auto options = fe::graph::SDPA_backward_attributes() .set_name("sdpa_backward_cudnn") .set_attn_scale(graph.scalar("Scale", SCALE, float32)); if (do_causal) { options.set_causal_mask_bottom_right(do_causal); } if (mask_arr) { options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr)); } if (sinks) { options.set_sink_token(graph.tensor_4d("SINKS", SINKS, *sinks, 1)); } auto [d_q_, d_k_, d_v_] = graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options); graph.tensor(d_q_, D_Q, d_q)->set_output(true); graph.tensor(d_k_, D_K, d_k)->set_output(true); graph.tensor(d_v_, D_V, d_v)->set_output(true); CHECK_CUDNN_FE_ERROR(graph.prepare()); graph.select_behavior_notes( {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); CHECK_CUDNN_FE_ERROR(graph.build()); return graph; } } // namespace bool supports_sdpa_cudnn( const array& q, const array& k, const array& v, bool has_arr_mask, bool do_causal, Stream s) { static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SDPA", 1); if (!enabled) { return false; } // cuDNN SDPA requires Ampere and later. if (cu::device(s.device).compute_capability_major() < 8) { return false; } // Only use cuDNN for decoding when k/v are slices from fixed-size kv cache. if ((q.shape(2) == 1) && !use_cudnn_for_decoding(q, k, v, has_arr_mask)) { return false; } // cuDNN does not support bottom right mask when T_q > T_kv. if (do_causal && (q.shape(2) > k.shape(2))) { return false; } // D_qk and D_v must be a multiple of 8 with maximum value 128. 
if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) || (v.shape(-1) > 128)) { return false; } Dtype dtype = q.dtype(); return dtype == float16 || dtype == bfloat16; } void sdpa_cudnn( const array& q, array k, array v, float scale, array& o, std::optional& stats, bool do_causal, const std::optional& mask_arr, const std::optional& sinks, bool output_logsumexp, Stream s) { auto& encoder = cu::get_command_encoder(s); auto handle = encoder.device().get_cudnn_handle(); malloc_with_same_layout(encoder, o, q); // For decoding, unslice k/v and apply padding mask. std::optional seq_len_q; std::optional seq_len_kv; bool decoding = use_cudnn_for_decoding(q, k, v, mask_arr.has_value()); if (decoding) { int B = q.shape(0); std::vector seq_len_q_vec(B, q.shape(2)); std::vector seq_len_kv_vec(B, k.shape(2)); seq_len_q = array(seq_len_q_vec.begin(), {B, 1, 1, 1}); seq_len_kv = array(seq_len_kv_vec.begin(), {B, 1, 1, 1}); encoder.add_temporary(*seq_len_q); encoder.add_temporary(*seq_len_kv); k = unslice_kv(k); v = unslice_kv(v); encoder.add_temporary(k); encoder.add_temporary(v); } encoder.set_input_array(q); encoder.set_input_array(k); encoder.set_input_array(v); encoder.set_output_array(o); if (mask_arr) { encoder.set_input_array(*mask_arr); } if (sinks) { encoder.set_input_array(*sinks); } if (seq_len_q && seq_len_kv) { encoder.set_input_array(*seq_len_q); encoder.set_input_array(*seq_len_kv); } if (output_logsumexp) { stats->set_data(cu::malloc_async(stats->nbytes(), encoder)); encoder.set_output_array(*stats); } // Search cache. 
auto cache_key = build_sdpa_cache_key( encoder, q, k, v, do_causal, mask_arr, sinks, decoding, output_logsumexp); auto it = sdpa_cache().find(cache_key); if (it == sdpa_cache().end()) { auto graph = build_sdpa_graph( handle, q, k, v, do_causal, mask_arr, sinks, seq_len_q, seq_len_kv, output_logsumexp, o, stats); it = sdpa_cache().emplace(cache_key, std::move(graph)).first; } auto& graph = it->second; std::unordered_map variant_pack{ {Q, gpu_ptr(q)}, {K, gpu_ptr(k)}, {V, gpu_ptr(v)}, {SCALE, &scale}, {O, gpu_ptr(o)}}; if (mask_arr) { variant_pack[BIAS] = gpu_ptr(*mask_arr); } if (sinks) { variant_pack[SINKS] = gpu_ptr(*sinks); } if (seq_len_q && seq_len_kv) { variant_pack[SEQ_LEN_Q] = gpu_ptr(*seq_len_q); variant_pack[SEQ_LEN_KV] = gpu_ptr(*seq_len_kv); } if (output_logsumexp) { variant_pack[STATS] = gpu_ptr(*stats); } CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack))); } void sdpa_backward_cudnn( const array& q, const array& k, const array& v, float scale, const array& o, const array& stats, bool do_causal, const std::optional& mask_arr, const std::optional& sinks, const array& d_o, array& d_q, array& d_k, array& d_v, Stream s) { auto& encoder = cu::get_command_encoder(s); auto handle = encoder.device().get_cudnn_handle(); malloc_with_same_layout(encoder, d_q, q); malloc_with_same_layout(encoder, d_k, k); malloc_with_same_layout(encoder, d_v, v); encoder.set_input_array(q); encoder.set_input_array(k); encoder.set_input_array(v); encoder.set_input_array(o); encoder.set_input_array(stats); encoder.set_input_array(d_o); encoder.set_output_array(d_q); encoder.set_output_array(d_k); encoder.set_output_array(d_v); if (mask_arr) { encoder.set_input_array(*mask_arr); } if (sinks) { encoder.set_input_array(*sinks); } // Search cache. 
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr, sinks); auto it = sdpa_backward_cache().find(cache_key); if (it == sdpa_backward_cache().end()) { auto graph = build_sdpa_backward_graph( handle, q, k, v, do_causal, mask_arr, sinks, o, d_o, stats, d_q, d_k, d_v); it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first; } auto& graph = it->second; std::unordered_map variant_pack{ {Q, gpu_ptr(q)}, {K, gpu_ptr(k)}, {V, gpu_ptr(v)}, {SCALE, &scale}, {O, gpu_ptr(o)}, {STATS, gpu_ptr(stats)}, {D_O, gpu_ptr(d_o)}, {D_Q, gpu_ptr(d_q)}, {D_K, gpu_ptr(d_k)}, {D_V, gpu_ptr(d_v)}}; if (mask_arr) { variant_pack[BIAS] = gpu_ptr(*mask_arr); } if (sinks) { variant_pack[SINKS] = gpu_ptr(*sinks); } CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack))); } // Defined in scaled_dot_product_attention.cu file. bool supports_sdpa_vector( const array& q, const array& k, const array& v, bool has_arr_mask, bool output_logsumexp); void sdpa_vector( const array& q, const array& k, const array& v, float scale, array& o, bool do_causal, const std::optional& sinks, Stream s); namespace fast { bool ScaledDotProductAttention::use_fallback( const array& q, const array& k, const array& v, bool has_mask, bool has_arr_mask, bool do_causal, bool is_training, bool output_logsumexp, Stream s) { if (s.device == Device::cpu) { return true; } return !supports_sdpa_cudnn(q, k, v, has_arr_mask, do_causal, s) && !supports_sdpa_vector(q, k, v, has_arr_mask, output_logsumexp); } bool ScaledDotProductAttention::supports_bool_mask() { return false; } void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, std::vector& outputs) { nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu"); auto& s = stream(); array q = prepare_sdpa_input(inputs[0], s); array k = prepare_sdpa_input(inputs[1], s); array v = prepare_sdpa_input(inputs[2], s); array& out = outputs[0]; bool has_mask = inputs.size() - has_sinks_ > 3; bool has_arr_mask = has_mask && 
!do_causal_; std::optional mask_arr; if (has_arr_mask) { mask_arr = prepare_sdpa_input(inputs[3], s); } std::optional sinks; if (has_sinks_) { sinks = inputs.back(); } std::optional stats; if (output_logsumexp_) { stats = outputs[1]; } if (supports_sdpa_cudnn(q, k, v, has_arr_mask, do_causal_, s)) { if (sinks) { sinks = prepare_sdpa_sinks(*sinks, s); } sdpa_cudnn( q, k, v, scale_, out, stats, do_causal_, mask_arr, sinks, output_logsumexp_, s); } else { sdpa_vector(q, k, v, scale_, out, do_causal_, sinks, s); } } bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { // The frontend adds a padding mask when sequence length is not a multiple of // tile size. if (q.shape(2) % 128 != 0) { return true; } return s.device == Device::cpu; } void ScaledDotProductAttentionVJP::eval_gpu( const std::vector& inputs, std::vector& outputs) { nvtx3::scoped_range r("ScaledDotProductAttentionVJP::eval_gpu"); auto& s = stream(); assert(inputs.size() >= 6); int primals_size = inputs.size() - 3; bool has_arr_mask = primals_size > 3 + has_sinks_; array q = prepare_sdpa_input(inputs[0], s); array k = prepare_sdpa_input(inputs[1], s); array v = prepare_sdpa_input(inputs[2], s); array o = prepare_sdpa_input(inputs[primals_size], s); array stats = prepare_sdpa_input(inputs[primals_size + 1], s); array d_o = prepare_sdpa_input(inputs[primals_size + 2], s); std::optional mask_arr; if (has_arr_mask) { mask_arr = prepare_sdpa_input(inputs[3], s); } std::optional sinks; if (has_sinks_) { sinks = prepare_sdpa_sinks(inputs.back(), s); } assert(outputs.size() == 3); auto& d_q = outputs[0]; auto& d_k = outputs[1]; auto& d_v = outputs[2]; sdpa_backward_cudnn( q, k, v, scale_, o, stats, do_causal_, mask_arr, sinks, d_o, d_q, d_k, d_v, s); } } // namespace fast } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/scaled_dot_product_attention.cu ================================================ // Copyright © 2025 Apple Inc. 
// Required for using M_LOG2E in MSVC. #define _USE_MATH_DEFINES #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include #include namespace mlx::core { namespace cu { namespace cg = cooperative_groups; #define PRAGMA_LOOP_UNROLL #pragma unroll struct AttnParams { int B; int H; int D; int qL; int kL; int gqa_factor; float scale; int64_t Q_strides[3]; int64_t K_strides[3]; int64_t V_strides[3]; int64_t O_strides[3]; }; template __global__ void kernel_sdpav_1pass( const T* Q, const T* K, const T* V, T* O, const T* sinks, __grid_constant__ const AttnParams params) { constexpr int BN = 32; constexpr int BD = 32; constexpr int v_per_thread = D / BD; const int inner_k_stride = BN * int(params.K_strides[2]); const int inner_v_stride = BN * int(params.V_strides[2]); typedef float U; U q[v_per_thread]; U k[v_per_thread]; U o[v_per_thread]; __shared__ U outputs[BN][BD + 1]; __shared__ U max_scores[BN]; __shared__ U sum_exp_scores[BN]; const U scale_log2 = params.scale * M_LOG2E; auto block = cg::this_thread_block(); auto warp = cg::tiled_partition<32>(block); const int lane_idx = warp.thread_rank(); const int warp_idx = warp.meta_group_rank(); // Adjust to thread block and thread const int batch_idx = blockIdx.z; const int head_idx = blockIdx.x; const int kv_head_idx = head_idx / params.gqa_factor; const int q_seq_idx = blockIdx.y; const int kv_seq_idx = warp_idx; Q += batch_idx * params.Q_strides[0] + // Batch head_idx * params.Q_strides[1] + // Head q_seq_idx * params.Q_strides[2]; // Sequence K += batch_idx * params.K_strides[0] + // Batch kv_head_idx * params.K_strides[1] + // Head kv_seq_idx * params.K_strides[2]; // Sequence V += batch_idx * params.V_strides[0] + // Batch kv_head_idx * params.V_strides[1] + // Head kv_seq_idx * params.V_strides[2]; // Sequence O += batch_idx * 
params.O_strides[0] + // Batch head_idx * params.O_strides[1] + // Head q_seq_idx * params.O_strides[2]; // Sequence // Read the query and 0 the output accumulator PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { q[i] = scale_log2 * static_cast(Q[v_per_thread * lane_idx + i]); } PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { o[i] = 0.f; } U max_score = Limits::finite_min(); U sum_exp_score = 0.f; if (sinks && warp_idx == 0) { max_score = M_LOG2E * static_cast(sinks[head_idx]); sum_exp_score = 1.f; } // For each key for (int i = kv_seq_idx; i < params.kL; i += BN) { bool use_key = true; if constexpr (do_causal) { use_key = i <= (params.kL - params.qL + q_seq_idx); } if (use_key) { // Read the key PRAGMA_LOOP_UNROLL for (int j = 0; j < v_per_thread; j++) { k[j] = K[v_per_thread * lane_idx + j]; } // Compute the i-th score U score = 0.f; PRAGMA_LOOP_UNROLL for (int j = 0; j < v_per_thread; j++) { score += q[j] * k[j]; } // Warp sum score = cg::reduce(warp, score, cg::plus()); // Update the accumulators U new_max = max(max_score, score); U factor = exp2f(max_score - new_max); U exp_score = exp2f(score - new_max); max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; // Update the output accumulator PRAGMA_LOOP_UNROLL for (int j = 0; j < v_per_thread; j++) { o[j] = o[j] * factor + exp_score * static_cast(V[v_per_thread * lane_idx + j]); } } // Move the pointers to the next kv K += inner_k_stride; V += inner_v_stride; } if (lane_idx == 0) { max_scores[warp_idx] = max_score; sum_exp_scores[warp_idx] = sum_exp_score; } block.sync(); max_score = max_scores[lane_idx]; U new_max = cg::reduce(warp, max_score, cg::greater()); U factor = exp2f(max_score - new_max); sum_exp_score = cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus()); sum_exp_score = sum_exp_score == 0 ? 
0 : __frcp_rn(sum_exp_score); // Now we need to aggregate all the outputs PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { outputs[lane_idx][warp_idx] = o[i]; block.sync(); U ot = outputs[warp_idx][lane_idx] * factor; o[i] = cg::reduce(warp, ot, cg::plus()) * sum_exp_score; block.sync(); } // And write the output if (lane_idx == 0) { PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { O[v_per_thread * warp_idx + i] = static_cast(o[i]); } } } template __global__ void kernel_sdpav_2pass_1( const T* Q, const T* K, const T* V, const T* sinks, float* partials, float* sums, float* maxs, __grid_constant__ const AttnParams params) { constexpr int BN = 8; constexpr int BD = 32; constexpr int blocks = 32; constexpr int v_per_thread = D / BD; const int inner_k_stride = blocks * BN * int(params.K_strides[2]); const int inner_v_stride = blocks * BN * int(params.V_strides[2]); typedef float U; U q[v_per_thread]; U k[v_per_thread]; U o[v_per_thread]; __shared__ U outputs[BN][BD + 1]; __shared__ U max_scores[BN]; __shared__ U sum_exp_scores[BN]; const U scale_log2 = params.scale * 1.44269504089f; auto block = cg::this_thread_block(); auto warp = cg::tiled_partition<32>(block); const int lane_idx = warp.thread_rank(); const int warp_idx = warp.meta_group_rank(); // Adjust to thread block and thread const int batch_idx = blockIdx.z / blocks; const int block_idx = blockIdx.z % blocks; const int head_idx = blockIdx.x; const int kv_head_idx = head_idx / params.gqa_factor; const int q_seq_idx = blockIdx.y; const int kv_seq_idx = block_idx * BN + warp_idx; Q += batch_idx * params.Q_strides[0] + // Batch head_idx * params.Q_strides[1] + // Head q_seq_idx * params.Q_strides[2]; // Sequence K += batch_idx * params.K_strides[0] + // Batch kv_head_idx * params.K_strides[1] + // Head kv_seq_idx * params.K_strides[2]; // Sequence V += batch_idx * params.V_strides[0] + // Batch kv_head_idx * params.V_strides[1] + // Head kv_seq_idx * params.V_strides[2]; // Sequence const int 
p_stride_s = blocks; const int p_stride_h = params.qL * p_stride_s; const int p_stride_b = params.H * p_stride_h; const int p_offset = batch_idx * p_stride_b + // Batch head_idx * p_stride_h + // Head q_seq_idx * p_stride_s + // Sequence block_idx; // Block partials += p_offset * D; sums += p_offset; maxs += p_offset; // Read the query and 0 the output accumulator PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { q[i] = scale_log2 * static_cast(Q[v_per_thread * lane_idx + i]); } PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { o[i] = 0.f; } U max_score = Limits::finite_min(); U sum_exp_score = 0.f; if (sinks && warp_idx == 0 && block_idx == 0) { max_score = M_LOG2E * static_cast(sinks[head_idx]); sum_exp_score = 1.f; } // For each key for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) { bool use_key = true; if constexpr (do_causal) { use_key = i <= (params.kL - params.qL + q_seq_idx); } if (use_key) { // Read the key PRAGMA_LOOP_UNROLL for (int j = 0; j < v_per_thread; j++) { k[j] = K[v_per_thread * lane_idx + j]; } // Compute the i-th score U score = 0.f; PRAGMA_LOOP_UNROLL for (int j = 0; j < v_per_thread; j++) { score += q[j] * k[j]; } // Warp sum score = cg::reduce(warp, score, cg::plus()); // Update the accumulators U new_max = max(max_score, score); U factor = exp2f(max_score - new_max); U exp_score = exp2f(score - new_max); max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; // Update the output accumulator PRAGMA_LOOP_UNROLL for (int j = 0; j < v_per_thread; j++) { o[j] = o[j] * factor + exp_score * static_cast(V[v_per_thread * lane_idx + j]); } } // Move the pointers to the next kv K += inner_k_stride; V += inner_v_stride; } if (lane_idx == 0) { max_scores[warp_idx] = max_score; sum_exp_scores[warp_idx] = sum_exp_score; } block.sync(); max_score = (lane_idx < BN) ? 
max_scores[lane_idx] : -1e9; U new_max = cg::reduce(warp, max_score, cg::greater()); U factor = exp2f(max_score - new_max); sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f; sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus()); // Write the sum and new max if (warp_idx == 0) { sums[0] = sum_exp_score; maxs[0] = new_max; } // Now we need to aggregate all the outputs auto ff = exp2f(max_scores[warp_idx] - new_max); PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { outputs[warp_idx][lane_idx] = o[i] * ff; block.sync(); if (warp_idx == 0) { U ot = outputs[0][lane_idx]; PRAGMA_LOOP_UNROLL for (int j = 1; j < BN; j++) { ot += outputs[j][lane_idx]; warp.sync(); } o[i] = ot; } block.sync(); } if (warp_idx == 0) { PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { partials[v_per_thread * lane_idx + i] = o[i]; } } } template __global__ void kernel_sdpav_2pass_2( const float* partials, const float* sums, const float* maxs, T* O, __grid_constant__ const AttnParams params) { constexpr int BN = 32; constexpr int BD = 32; constexpr int blocks = 32; constexpr int v_per_thread = D / BD; typedef float U; U o[v_per_thread]; __shared__ U outputs[BN][BD + 1]; auto block = cg::this_thread_block(); auto warp = cg::tiled_partition<32>(block); const int lane_idx = warp.thread_rank(); const int warp_idx = warp.meta_group_rank(); // Adjust to thread block and thread const int batch_idx = blockIdx.z; const int head_idx = blockIdx.x; const int q_seq_idx = blockIdx.y; const int p_stride_s = blocks; const int p_stride_h = params.qL * p_stride_s; const int p_stride_b = params.H * p_stride_h; const int p_offset = batch_idx * p_stride_b + // Batch head_idx * p_stride_h + // Head q_seq_idx * p_stride_s; // Sequence partials += p_offset * D + warp_idx * D; sums += p_offset; maxs += p_offset; O += batch_idx * params.O_strides[0] + // Batch head_idx * params.O_strides[1] + // Head q_seq_idx * params.O_strides[2]; // Sequence U max_score = 
maxs[lane_idx]; U new_max = cg::reduce(warp, max_score, cg::greater()); U factor = exp2f(max_score - new_max); U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus()); sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score); PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { o[i] = partials[v_per_thread * lane_idx + i]; } // Now we need to aggregate all the outputs PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { outputs[lane_idx][warp_idx] = o[i]; block.sync(); U ot = outputs[warp_idx][lane_idx] * factor; o[i] = cg::reduce(warp, ot, cg::plus()) * sum_exp_score; block.sync(); } // And write the output if (lane_idx == 0) { PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { O[v_per_thread * warp_idx + i] = static_cast(o[i]); } } } } // namespace cu namespace { template void dispatch_headdim(int n, F&& f) { switch (n) { case 64: f(std::integral_constant{}); break; case 96: f(std::integral_constant{}); break; case 128: f(std::integral_constant{}); break; } } void sdpa_vector_1pass_fallback( const Stream& s, cu::CommandEncoder& encoder, const array& q, const array& k, const array& v, const float scale, array& o, bool do_causal, const std::optional& sinks) { encoder.set_input_array(q); encoder.set_input_array(k); encoder.set_input_array(v); if (sinks) { encoder.set_input_array(*sinks); } encoder.set_output_array(o); cu::AttnParams params{ /* int B = */ q.shape(0), /* int H = */ q.shape(1), /* int D = */ q.shape(3), /* int qL = */ q.shape(2), /* int kL = */ k.shape(2), /* int gqa_factor = */ q.shape(1) / k.shape(1), /* float scale = */ scale, /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; dim3 grid_dim(params.H, params.qL, params.B); dim3 block_dim(1024, 1, 1); 
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) { dispatch_bool(do_causal, [&](auto do_causal) { dispatch_headdim(params.D, [&](auto headdim) { using DataType = cuda_type_t; auto kernel = cu::kernel_sdpav_1pass; encoder.add_kernel_node( kernel, grid_dim, block_dim, gpu_ptr(q), gpu_ptr(k), gpu_ptr(v), gpu_ptr(o), sinks ? gpu_ptr(*sinks) : nullptr, params); }); }); }); } void sdpa_vector_2pass_fallback( const Stream& s, cu::CommandEncoder& encoder, const array& q, const array& k, const array& v, const float scale, array& o, bool do_causal, const std::optional& sinks) { cu::AttnParams params{ /* int B = */ q.shape(0), /* int H = */ q.shape(1), /* int D = */ q.shape(3), /* int qL = */ q.shape(2), /* int kL = */ k.shape(2), /* int gqa_factor = */ q.shape(1) / k.shape(1), /* float scale = */ scale, /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; // Allocate the intermediates int blocks = 32; Shape intermediate_shape; intermediate_shape.reserve(o.ndim() + 1); intermediate_shape.insert( intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1); intermediate_shape.push_back(blocks); intermediate_shape.push_back(o.shape().back()); array intermediate(intermediate_shape, float32, nullptr, {}); intermediate_shape.pop_back(); array sums(intermediate_shape, float32, nullptr, {}); array maxs(std::move(intermediate_shape), float32, nullptr, {}); intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder)); sums.set_data(cu::malloc_async(sums.nbytes(), encoder)); maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder)); encoder.add_temporary(intermediate); encoder.add_temporary(sums); encoder.add_temporary(maxs); dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) { 
dispatch_bool(do_causal, [&](auto do_causal) { dispatch_headdim(params.D, [&](auto headdim) { using DataType = cuda_type_t; { auto kernel = cu:: kernel_sdpav_2pass_1; encoder.set_input_array(q); encoder.set_input_array(k); encoder.set_input_array(v); if (sinks) { encoder.set_input_array(*sinks); } encoder.set_output_array(intermediate); encoder.set_output_array(sums); encoder.set_output_array(maxs); dim3 grid_dim(params.H, params.qL, params.B * 32); dim3 block_dim(8 * 32, 1, 1); encoder.add_kernel_node( kernel, grid_dim, block_dim, gpu_ptr(q), gpu_ptr(k), gpu_ptr(v), sinks ? gpu_ptr(*sinks) : nullptr, gpu_ptr(intermediate), gpu_ptr(sums), gpu_ptr(maxs), params); } { auto kernel = cu:: kernel_sdpav_2pass_2; encoder.set_input_array(intermediate); encoder.set_input_array(sums); encoder.set_input_array(maxs); encoder.set_output_array(o); dim3 grid_dim(params.H, params.qL, params.B); dim3 block_dim(1024, 1, 1); encoder.add_kernel_node( kernel, grid_dim, block_dim, gpu_ptr(intermediate), gpu_ptr(sums), gpu_ptr(maxs), gpu_ptr(o), params); } }); }); }); } void sdpa_vector_fallback( const Stream& s, cu::CommandEncoder& encoder, const array& q, const array& k, const array& v, const float scale, array& o, bool do_causal, const std::optional& sinks) { int kL = k.shape(2); if (kL > 1024) { return sdpa_vector_2pass_fallback( s, encoder, q, k, v, scale, o, do_causal, sinks); } else { return sdpa_vector_1pass_fallback( s, encoder, q, k, v, scale, o, do_causal, sinks); } } } // namespace bool supports_sdpa_vector( const array& q, const array& k, const array& v, bool has_arr_mask, bool output_logsumexp) { if (output_logsumexp) { return false; } const int value_head_dim = v.shape(-1); const int query_head_dim = q.shape(-1); const int query_sequence_length = q.shape(2); const int key_sequence_length = k.shape(2); const bool sdpa_supported_head_dim = query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128); const bool 
supported_vector_config = sdpa_supported_head_dim && query_sequence_length < 4; return supported_vector_config && !has_arr_mask; } void sdpa_vector( const array& q_pre, const array& k_pre, const array& v_pre, float scale, array& o, bool do_causal, const std::optional& sinks_pre, Stream s) { auto& encoder = cu::get_command_encoder(s); std::vector copies; // Define some copy functions to ensure the layout of the inputs is as // expected. copies.reserve(4); auto copy_unless = [&copies, &s]( auto predicate, const array& arr) -> const array& { if (!predicate(arr)) { array arr_copy = contiguous_copy_gpu(arr, s); copies.push_back(std::move(arr_copy)); return copies.back(); } else { return arr; } }; // Checks that the headdim dimension has stride 1. auto is_matrix_contiguous = [](const array& arr) { return arr.strides(-1) == 1; }; std::optional sinks = std::nullopt; if (sinks_pre) { sinks = copy_unless(is_matrix_contiguous, sinks_pre.value()); } // We are in vector mode ie single query if (q_pre.shape(2) < 4) { auto q_copy_unless = [](const array& arr) { if (arr.flags().row_contiguous) { return true; } auto& strides = arr.strides(); auto& shape = arr.shape(); if (shape[0] == 1 || shape[1] == 1) { // If either the batch or head dimension is a singleton, the other can // be transposed with the sequence dimension auto bidx = shape[0] == 1 ? 
1 : 0; return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && (strides[bidx] == shape[3]); } return false; }; auto kv_copy_unless = [](const array& arr) { // keys and values should be copied if: // - the last dimension is not contiguous // - the batch and head dim are not contiguous auto& strides = arr.strides(); auto& shape = arr.shape(); if (strides.back() != 1) { return false; } if (shape[0] == 1 || shape[1] == 1) { return true; } return (strides[0] == strides[1] * shape[1]); }; const auto& q = copy_unless(q_copy_unless, q_pre); const auto& k = copy_unless(kv_copy_unless, k_pre); const auto& v = copy_unless(kv_copy_unless, v_pre); // Donate the query if possible if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { o.copy_shared_buffer(q); } else { int64_t str_oD = 1; int64_t str_oH = o.shape(3); int64_t str_oL = o.shape(1) * str_oH; int64_t str_oB = o.shape(2) * str_oL; array::Flags flags{ /* bool contiguous = */ 1, /* bool row_contiguous = */ o.shape(2) == 1, /* bool col_contiguous = */ o.size() == o.shape(3), }; o.set_data( cu::malloc_async(o.nbytes(), encoder), o.size(), {str_oB, str_oH, str_oL, str_oD}, flags); } for (const auto& cp : copies) { encoder.add_temporary(cp); } sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks); } // Full attention mode should never reach here else { throw std::runtime_error("Doesn't support matrix yet."); } } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/scan.cu ================================================ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/scan.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include
#include
#include
#include

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Accumulator type used when scanning elements of type T. The generic case
// accumulates in T itself; the explicit specialization below accumulates in
// int32_t.
template
struct ScanResult {
  using type = T;
};

template <>
struct ScanResult {
  using type = int32_t;
};

// Starting (identity) value of the scan accumulator; this definition returns
// Limits::min().
template
struct ReduceInit {
  static constexpr __host__ __device__ T value() {
    return Limits::min();
  }
};

// Loads chunk `index` (N_READS elements) of a row of length `size` into
// `values`, converting each element with cast_to. Slots past the end of the
// row are filled with `init`. When `reverse` is set, chunks are counted from
// the end of the row and stored back-to-front, so values[0] is the element
// closest to the end of the row.
template
inline __device__ void
load_values(int index, const T* in, U (&values)[N_READS], int size, U init) {
  int remaining = size - index * N_READS;
  if constexpr (reverse) {
    in += remaining - N_READS;
    if (remaining < N_READS) {
      for (int i = 0; i < N_READS; ++i) {
        values[N_READS - i - 1] =
            (N_READS - i - 1 < remaining) ? cast_to(in[i]) : init;
      }
    } else {
      for (int i = 0; i < N_READS; ++i) {
        values[N_READS - i - 1] = cast_to(in[i]);
      }
    }
  } else {
    in += index * N_READS;
    if (remaining < N_READS) {
      for (int i = 0; i < N_READS; ++i) {
        values[i] = (i < remaining) ? cast_to(in[i]) : init;
      }
    } else {
      for (int i = 0; i < N_READS; ++i) {
        values[i] = cast_to(in[i]);
      }
    }
  }
}

// Mirror of load_values: writes chunk `index` of `values` back into a row of
// length `size`, displaced by `offset` elements (towards the end of the row
// for a forward scan, towards its start for a reverse scan). Writes that
// would fall outside the row are dropped.
template
inline __device__ void
store_values(int index, T* out, T (&values)[N_READS], int size) {
  int start = index * N_READS + offset;
  int remaining = size - start;
  if constexpr (reverse) {
    out += remaining - N_READS;
    if (remaining < N_READS) {
      for (int i = 0; i < N_READS; ++i) {
        if (N_READS - i - 1 < remaining) {
          out[i] = values[N_READS - i - 1];
        }
      }
    } else {
      for (int i = 0; i < N_READS; ++i) {
        out[i] = values[N_READS - i - 1];
      }
    }
  } else {
    out += start;
    if (remaining < N_READS) {
      for (int i = 0; i < N_READS; ++i) {
        if (i < remaining) {
          out[i] = values[i];
        }
      }
    } else {
      for (int i = 0; i < N_READS; ++i) {
        out[i] = values[i];
      }
    }
  }
}

// Scans contiguous rows of length `axis_size`, one thread block per row (the
// row is selected by grid.block_rank()). The row is processed in chunks of
// block.size() * N_READS elements:
//   1. each thread scans its N_READS values in registers,
//   2. thread totals are combined with a warp-level exclusive scan,
//   3. warp totals are combined by warp 0 through shared memory,
//   4. the running total of all earlier chunks is carried in `prefix`.
// Supports at most WARP_SIZE warps per block (size of `warp_sums`).
template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
__global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition(block);

  in += grid.block_rank() * axis_size;
  out += grid.block_rank() * axis_size;

  __shared__ U warp_sums[WARP_SIZE];
  // Carries the block-wide running total between chunk iterations. It is kept
  // separate from warp_sums because warp_sums[0] is rewritten early in the
  // next iteration (before its first barrier). Reusing it for the carry
  // raced with warps that had not yet read the prefix.
  __shared__ U block_prefix;

  Op op;
  U init = ReduceInit::value();
  U prefix = init;

  // Scan per block.
  for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) {
    int32_t index = r * block.size() + block.thread_rank();
    U values[N_READS];
    load_values(index, in, values, axis_size, init);

    // Compute an inclusive scan per thread.
    for (int i = 1; i < N_READS; ++i) {
      values[i] = op(values[i], values[i - 1]);
    }

    // Compute exclusive scan of thread sums. Lane 0's exclusive result is not
    // meaningful, so it is replaced with the identity.
    U prev_thread_sum = cg::exclusive_scan(warp, values[N_READS - 1], op);
    if (warp.thread_rank() == 0) {
      prev_thread_sum = init;
    }

    // Write warp's sum to shared memory.
    if (warp.thread_rank() == WARP_SIZE - 1) {
      warp_sums[warp.meta_group_rank()] =
          op(prev_thread_sum, values[N_READS - 1]);
    }
    block.sync();

    // Compute exclusive scan of warp sums.
    if (warp.meta_group_rank() == 0) {
      U prev_warp_sum =
          cg::exclusive_scan(warp, warp_sums[warp.thread_rank()], op);
      if (warp.thread_rank() == 0) {
        prev_warp_sum = init;
      }
      warp_sums[warp.thread_rank()] = prev_warp_sum;
    }
    block.sync();

    // Compute the output.
    for (int i = 0; i < N_READS; ++i) {
      values[i] = op(values[i], prefix);
      values[i] = op(values[i], warp_sums[warp.meta_group_rank()]);
      values[i] = op(values[i], prev_thread_sum);
    }

    // Write the values. For exclusive scans the results are stored displaced
    // by the offset template argument of store_values, and the first output
    // slot (the last one when reversed) is set to the identity.
    if (inclusive) {
      store_values(index, out, values, axis_size);
    } else {
      store_values(index, out, values, axis_size);
      if (reverse) {
        if (block.thread_rank() == 0 && index == 0) {
          out[axis_size - 1] = init;
        }
      } else {
        if (block.thread_rank() == 0 && index == 0) {
          out[0] = init;
        }
      }
    }
    block.sync();

    // Share the prefix: the last thread of the last warp holds the inclusive
    // total of everything processed so far.
    if ((warp.meta_group_rank() == warp.meta_group_size() - 1) &&
        (warp.thread_rank() == WARP_SIZE - 1)) {
      block_prefix = values[N_READS - 1];
    }
    block.sync();
    prefix = block_prefix;
  }
}

// Scans along an axis whose consecutive elements are `stride` apart.
// Block layout:
//   - grid.block_rank() / stride_blocks selects the outer slice;
//   - grid.block_rank() % stride_blocks selects a group of BN adjacent columns.
// The block walks the axis BM positions at a time. Each step stages a BM x BN
// tile in shared memory, so each warp can run cg::inclusive_scan down
// n_scans columns (one row per lane). Per-column running totals are carried
// in `prefix`.
template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    int BM,
    int BN,
    bool inclusive,
    bool reverse>
__global__ void strided_scan(
    const T* in,
    U* out,
    int32_t axis_size,
    int64_t stride,
    int64_t stride_blocks) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition(block);

  // Rows of the tile are padded, presumably to reduce shared-memory bank
  // conflicts on the column-wise accesses below.
  constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U);
  constexpr int n_warps = BN / N_READS;
  constexpr int n_scans = BN / n_warps;

  __shared__ U read_buffer[BM * BN_pad];

  Op op;
  U init = ReduceInit::value();
  U values[n_scans];
  U prefix[n_scans];
  for (int i = 0; i < n_scans; ++i) {
    prefix[i] = init;
  }

  // Compute offsets.
  int64_t offset = (grid.block_rank() / stride_blocks) * axis_size * stride;
  int64_t global_index_x = (grid.block_rank() % stride_blocks) * BN;
  uint32_t read_offset_y = (block.thread_rank() * N_READS) / BN;
  uint32_t read_offset_x = (block.thread_rank() * N_READS) % BN;
  uint32_t scan_offset_y = warp.thread_rank();
  uint32_t scan_offset_x = warp.meta_group_rank() * n_scans;

  uint32_t stride_limit = stride - global_index_x;
  in += offset + global_index_x + read_offset_x;
  out += offset + global_index_x + read_offset_x;
  // read_into: the N_READS slots this thread loads/stores (one tile row).
  // read_from: the n_scans slots this thread scans (one tile column each).
  U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x;
  U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x;

  for (uint32_t j = 0; j < axis_size; j += BM) {
    // Calculate the indices for the current thread.
    uint32_t index_y = j + read_offset_y;
    uint32_t check_index_y = index_y;
    if (reverse) {
      index_y = axis_size - 1 - index_y;
    }

    // Read into shared memory, padding out-of-range slots with init.
    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
      for (int i = 0; i < N_READS; ++i) {
        read_into[i] = in[index_y * stride + i];
      }
    } else {
      for (int i = 0; i < N_READS; ++i) {
        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
          read_into[i] = in[index_y * stride + i];
        } else {
          read_into[i] = init;
        }
      }
    }
    block.sync();

    // Read strided into registers.
    for (int i = 0; i < n_scans; ++i) {
      values[i] = read_from[i];
    }

    // Perform the scan; the last lane's result becomes the next prefix.
    for (int i = 0; i < n_scans; ++i) {
      values[i] = cg::inclusive_scan(warp, values[i], op);
      values[i] = op(values[i], prefix[i]);
      prefix[i] = warp.shfl(values[i], WARP_SIZE - 1);
    }

    // Write back to shared memory.
    for (int i = 0; i < n_scans; ++i) {
      read_from[i] = values[i];
    }
    block.sync();

    // Write to device memory. Exclusive scans write init at the first axis
    // position and shift every result one position along the scan direction.
    if (!inclusive) {
      if (check_index_y == 0) {
        if ((read_offset_x + N_READS) < stride_limit) {
          for (int i = 0; i < N_READS; ++i) {
            out[index_y * stride + i] = init;
          }
        } else {
          for (int i = 0; i < N_READS; ++i) {
            if ((read_offset_x + i) < stride_limit) {
              out[index_y * stride + i] = init;
            }
          }
        }
      }
      if (reverse) {
        index_y -= 1;
        check_index_y += 1;
      } else {
        index_y += 1;
        check_index_y += 1;
      }
    }
    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
      for (int i = 0; i < N_READS; ++i) {
        out[index_y * stride + i] = read_into[i];
      }
    } else {
      for (int i = 0; i < N_READS; ++i) {
        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
          out[index_y * stride + i] = read_into[i];
        }
      }
    }
  }
}

} // namespace cu

// Calls `f` with a type_identity tag for the device operator matching
// `scan_op`; throws std::invalid_argument for an unknown reduce type.
template
void dispatch_scan_ops(Scan::ReduceType scan_op, F&& f) {
  if (scan_op == Scan::ReduceType::Max) {
    f(type_identity{});
  } else if (scan_op == Scan::ReduceType::Min) {
    f(type_identity{});
  } else if (scan_op == Scan::ReduceType::Sum) {
    f(type_identity{});
  } else if (scan_op == Scan::ReduceType::Prod) {
    f(type_identity{});
  } else if (scan_op == Scan::ReduceType::LogAddExp) {
    f(type_identity{});
  } else {
    throw std::invalid_argument("Unknown reduce type.");
  }
}

// Human-readable operator name used in error messages.
template
const char* op_to_string() {
  if (cuda::std::is_same_v) {
    return "Max";
  } else if (cuda::std::is_same_v) {
    return "Min";
  } else if (cuda::std::is_same_v) {
    return "Sum";
  } else if (cuda::std::is_same_v) {
    return "Prod";
  } else if (cuda::std::is_same_v) {
    return "LogAddExp";
  } else {
    throw std::invalid_argument("Unknown op.");
  }
}

// Whether the operator can be applied to the element type. One operator is
// restricted to inexact types (is_inexact_v); all others are always allowed.
template
constexpr bool supports_scan_op() {
  if constexpr (cuda::std::is_same_v) {
    return is_inexact_v;
  } else {
    return true;
  }
}

// Encodes a scan of `in` along `axis` into `out`, whose storage must already
// be set. Uses contiguous_scan when the axis has unit stride and strided_scan
// otherwise. Throws std::runtime_error when the operator does not support the
// input dtype.
void scan_gpu_inplace(
    array in,
    array& out,
    Scan::ReduceType reduce_type,
    int axis,
    bool reverse,
    bool inclusive,
    const Stream& s) {
  auto& encoder = cu::get_command_encoder(s);
  constexpr int N_READS = 4;
  int32_t axis_size = in.shape(axis);
  bool contiguous = in.strides()[axis] == 1;

  encoder.set_input_array(in);
  encoder.set_output_array(out);

  dispatch_all_types(in.dtype(), [&](auto type_tag) {
    using T = cuda_type_t;
    dispatch_scan_ops(reduce_type, [&](auto scan_op_tag) {
      using Op = MLX_GET_TYPE(scan_op_tag);
      if constexpr (supports_scan_op()) {
        using U = typename cu::ScanResult::type;
        dispatch_bool(inclusive, [&](auto inclusive_tag) {
          dispatch_bool(reverse, [&](auto reverse_tag) {
            if (contiguous) {
              auto kernel = cu::contiguous_scan<
                  T,
                  U,
                  Op,
                  N_READS,
                  inclusive_tag.value,
                  reverse_tag.value>;
              // Enough whole warps to cover the row in one chunk, capped at
              // WARP_SIZE warps (the capacity of warp_sums in the kernel).
              int block_dim = cuda::ceil_div(axis_size, N_READS);
              block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;
              block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);
              encoder.add_kernel_node(
                  kernel,
                  in.data_size() / axis_size,
                  block_dim,
                  gpu_ptr(in),
                  gpu_ptr(out),
                  axis_size);
            } else {
              constexpr int BM = WARP_SIZE;
              constexpr int BN = WARP_SIZE;
              auto kernel = cu::strided_scan<
                  T,
                  U,
                  Op,
                  N_READS,
                  BM,
                  BN,
                  inclusive_tag.value,
                  reverse_tag.value>;
              int64_t stride = in.strides()[axis];
              int64_t stride_blocks = cuda::ceil_div(stride, BN);
              dim3 num_blocks = get_2d_grid_dims(
                  in.shape(), in.strides(), axis_size * stride);
              // Fold the column groups into whichever grid dimension fits.
              if (num_blocks.x * stride_blocks <= UINT32_MAX) {
                num_blocks.x *= stride_blocks;
              } else {
                num_blocks.y *= stride_blocks;
              }
              int block_dim = (BN / N_READS) * WARP_SIZE;
              encoder.add_kernel_node(
                  kernel,
                  num_blocks,
                  block_dim,
                  gpu_ptr(in),
                  gpu_ptr(out),
                  axis_size,
                  stride,
                  stride_blocks);
            }
          });
        });
      } else {
        throw std::runtime_error(
            fmt::format(
                "Can not do scan op {} on inputs of {} with result of {}.",
                op_to_string(),
                dtype_to_string(in.dtype()),
                dtype_to_string(out.dtype())));
      }
    });
  });
}

// Allocates (or donates) the output with the input's layout, then encodes the
// scan. Inputs that are not contiguous, or that are broadcast along the scan
// axis (stride 0), are first copied to a contiguous buffer.
void Scan::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("Scan::eval_gpu");
  assert(inputs.size() == 1);
  auto in = inputs[0];
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  if (in.flags().contiguous && in.strides()[axis_] != 0) {
    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
      out.copy_shared_buffer(in);
    } else {
      out.set_data(
          cu::malloc_async(in.data_size() * out.itemsize(), encoder),
          in.data_size(),
          in.strides(),
          in.flags());
    }
  } else {
    in = contiguous_copy_gpu(in, s);
    out.copy_shared_buffer(in);
  }

  scan_gpu_inplace(in, out, reduce_type_, axis_, reverse_, inclusive_, s);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/slicing.cpp
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/slicing.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/dtype_utils.h"

#include

namespace mlx::core {

// Concatenates `inputs` along `axis` into a freshly allocated `out`. Each
// input is copied into a strided view of `out` that starts at the running sum
// of the preceding inputs' extents along `axis`. The copies are encoded while
// a concurrent context is held (presumably letting the independent copies
// overlap).
void concatenate_gpu(
    const std::vector& inputs,
    array& out,
    int axis,
    const Stream& s) {
  std::vector sizes;
  sizes.push_back(0);
  for (auto& p : inputs) {
    sizes.push_back(p.shape(axis));
  }
  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());

  auto& encoder = cu::get_command_encoder(s);
  out.set_data(cu::malloc_async(out.nbytes(), encoder));

  auto strides = out.strides();
  auto flags = out.flags();
  flags.row_contiguous = false;
  flags.col_contiguous = false;
  flags.contiguous = false;
  auto concurrent = encoder.concurrent_context();
  for (int i = 0; i < inputs.size(); i++) {
    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
    size_t data_offset = strides[axis] * sizes[i];
    out_slice.copy_shared_buffer(
        out, strides, flags, out_slice.size(), data_offset);
    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
  }
}

// Returns a 1-element int64 array holding sum_i indices[i] * strides[axes[i]].
// The value is computed on the device by a single-thread JIT kernel; the JIT
// module is named by the index dtype and the number of axes. The indices
// buffer is reused for the result when it is donatable and large enough.
array compute_dynamic_offset(
    const array& indices,
    const Strides& strides,
    const std::vector& axes,
    const Stream& s) {
  Dtype dtype = indices.dtype();
  int nidx = axes.size();

  std::string module_name =
      fmt::format("compute_dynamic_offset_{}_{}", dtype_to_string(dtype), nidx);
  std::string kernel_name = fmt::format(
      "mlx::core::cu::compute_dynamic_offset<{}, {}>",
      dtype_to_cuda_type(dtype),
      nidx);

  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::string source = R"(
        #include "mlx/backend/cuda/device/utils.cuh"

        namespace mlx::core::cu {

        template
        __global__ void compute_dynamic_offset(
            const T* indices,
            int64_t* offset,
            const __grid_constant__ Strides strides,
            const __grid_constant__ cuda::std::array axes) {
          int64_t acc = 0;
          #pragma unroll
          for (int i = 0; i < NIDX; ++i) {
            acc += indices[i] * strides[axes[i]];
          }
          *offset = acc;
        }

        } // namespace mlx::core::cu
    )";
    return std::make_tuple(false, std::move(source), std::vector{kernel_name});
  });

  auto& encoder = cu::get_command_encoder(s);
  // Prepare output.
  array offset({1}, int64, nullptr, {});
  bool donate = indices.is_donatable() &&
      (indices.data_size() * indices.itemsize()) >= offset.itemsize();
  if (donate) {
    offset.copy_shared_buffer(indices);
  } else {
    offset.set_data(cu::malloc_async(offset.itemsize(), encoder));
  }

  encoder.add_temporary(offset);
  encoder.set_input_array(indices);
  encoder.set_output_array(offset);

  cu::KernelArgs args;
  args.append(indices);
  args.append(offset);
  args.append_ndim(strides);
  args.append(axes);

  auto kernel = mod.get_kernel(kernel_name);
  encoder.add_kernel_node_raw(kernel, 1, 1, {}, 0, args.args());

  return offset;
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/softmax.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include
#include
#include
#include

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template
inline __device__ T softmax_exp(T x) {
  // Softmax doesn't need a high-precision exponential: x = v - max(v) lies in
  // (-inf, 0], and the result is subsequently divided by sum(exp(x_i)).
  return __expf(x);
}

// Row-wise softmax. Each block handles one row of length `axis_size`; each of
// its BLOCK_DIM threads loads N_READS values per step. Pass 1 computes the
// row max and normalizer in a single sweep using online rescaling. It then
// reduces them within each warp and across warps via shared memory. Pass 2
// writes exp(x - max) / normalizer. Math is done in AccT.
template
__global__ void softmax(const T* in, T* out, int axis_size) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition(block);

  in += grid.block_rank() * axis_size;
  out += grid.block_rank() * axis_size;

  cg::greater max_op;
  cg::plus plus_op;

  // Thread reduce.
  AccT prevmax;
  AccT maxval = Limits::finite_min();
  AccT normalizer = cast_to(0);
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    auto vals = load_vector(in, index, axis_size, Limits::min());
    prevmax = maxval;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      maxval = max_op(maxval, static_cast(vals[i]));
    }
    // Online normalizer calculation for softmax:
    // https://github.com/NVIDIA/online-softmax
    normalizer = normalizer * softmax_exp(prevmax - maxval);
#pragma unroll
    for (int i = 0; i < N_READS; i++) {
      normalizer =
          normalizer + softmax_exp(static_cast(vals[i]) - maxval);
    }
  }

  // First warp reduce; rescale the normalizer to the warp-wide max.
  prevmax = maxval;
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  normalizer = cg::reduce(warp, normalizer, plus_op);

  __shared__ AccT local_max[WARP_SIZE];
  __shared__ AccT local_normalizer[WARP_SIZE];

  // Write to shared memory and do second warp reduce across warps.
  prevmax = maxval;
  if (warp.thread_rank() == 0) {
    local_max[warp.meta_group_rank()] = maxval;
  }
  block.sync();
  maxval = warp.thread_rank() < warp.meta_group_size()
      ? local_max[warp.thread_rank()]
      : Limits::min();
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  if (warp.thread_rank() == 0) {
    local_normalizer[warp.meta_group_rank()] = normalizer;
  }
  block.sync();
  normalizer = warp.thread_rank() < warp.meta_group_size()
      ? local_normalizer[warp.thread_rank()]
      : AccT{};
  normalizer = cg::reduce(warp, normalizer, plus_op);
  normalizer = 1 / normalizer;

  // Write output.
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    auto vals = load_vector(in, index, axis_size, T(0));
    for (int i = 0; i < N_READS; i++) {
      vals[i] = softmax_exp(static_cast(vals[i]) - maxval) * normalizer;
    }
    store_vector(out, index, vals, axis_size);
  }
}

} // namespace cu

// Softmax over the last axis. The input is copied when its last dimension is
// not contiguous; otherwise the output reuses (when donatable) or mirrors the
// input layout. When precise_ is set and the input is not float32, the second
// kernel instantiation (with a different accumulation type) is used.
void Softmax::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("Softmax::eval_gpu");
  assert(inputs.size() == 1);
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Make sure that the last dimension is contiguous.
  auto set_output = [&s, &out, &encoder](const array& x) {
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      if (x.is_donatable()) {
        out.copy_shared_buffer(x);
      } else {
        out.set_data(
            cu::malloc_async(x.data_size() * x.itemsize(), encoder),
            x.data_size(),
            x.strides(),
            x.flags());
      }
      return x;
    } else {
      array x_copy = contiguous_copy_gpu(x, s);
      out.copy_shared_buffer(x_copy);
      return x_copy;
    }
  };

  array in = set_output(inputs[0]);
  bool precise = in.dtype() != float32 && precise_;

  int axis_size = in.shape().back();
  int n_rows = in.data_size() / axis_size;

  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
    using DataType = cuda_type_t;
    constexpr int N_READS = 16 / sizeof(DataType);
    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
      auto kernel = cu::softmax;
      if (precise) {
        kernel = cu::softmax;
      }
      encoder.add_kernel_node(
          kernel,
          n_rows,
          block_dim(),
          gpu_ptr(in),
          gpu_ptr(out),
          axis_size);
    });
  });
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/sort.cu
================================================
// Copyright © 2025 Apple Inc.
#include
#include
#include

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include
#include
#include

namespace mlx::core {

// Number of elements each thread holds in registers during a block sort.
// Note: the ThreadSort template parameter of the same name shadows this.
constexpr int N_PER_THREAD = 8;

namespace cu {

// Returns a quiet NaN of type T. Only declared generically; definitions are
// provided by the explicit specializations for float, double, __half and
// __nv_bfloat16 below.
template
__device__ __forceinline__ T nan_value();

template <>
__device__ __forceinline__ float nan_value() {
  return cuda::std::numeric_limits::quiet_NaN();
}

template <>
__device__ __forceinline__ double nan_value() {
  return cuda::std::numeric_limits::quiet_NaN();
}

template <>
__device__ __forceinline__ __half nan_value<__half>() {
  // Built from a float NaN and converted to half precision.
  return __float2half(cuda::std::numeric_limits::quiet_NaN());
}

template <>
__device__ __forceinline__ __nv_bfloat16 nan_value<__nv_bfloat16>() {
  // Built from a float NaN and converted to bfloat16.
  return __float2bfloat16(cuda::std::numeric_limits::quiet_NaN());
}

// Padding value for unused sort slots. The generic case returns
// Limits::max(). The specialization below returns nan_value(); presumably it
// is selected for floating-point types — TODO confirm. LessThan orders NaN
// after every non-NaN value, so this padding sorts to the end either way.
template
struct InitValue {
  __device__ __forceinline__ static T value() {
    return Limits::max();
  }
};

template
struct InitValue>> {
  __device__ __forceinline__ static T value() {
    return nan_value();
  }
};

// Swaps two values in place.
template
__device__ __forceinline__ void thread_swap(T& a, T& b) {
  T w = a;
  a = b;
  b = w;
}

// Ascending comparator with NaN handling: when either operand is NaN, a < b is
// true only if a is a number and b is NaN. NaNs therefore sort after all
// numbers and compare equal to each other. init() supplies the padding value.
template
struct LessThan {
  __device__ __forceinline__ static T init() {
    return InitValue::value();
  }

  __device__ __forceinline__ bool operator()(T a, T b) const {
    if constexpr (is_floating_v) {
      bool an = cuda::std::isnan(a);
      bool bn = cuda::std::isnan(b);
      // Non-short-circuit bitwise ops on bools (presumably to avoid branches).
      if (an | bn) {
        return (!an) & bn;
      }
    }
    return a < b;
  }
};

// Sorts the N_PER_THREAD values held by one thread with odd-even
// transposition sort. There are N_PER_THREAD passes; pass i compares adjacent
// pairs starting at index (i & 1). Elements swap only when the right one is
// strictly smaller, so equal elements keep their order. When ARG_SORT is set,
// `idxs` is permuted alongside `vals`.
template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    int N_PER_THREAD,
    typename CompareOp>
struct ThreadSort {
  __device__ __forceinline__ static void sort(
      ValT (&vals)[N_PER_THREAD],
      IdxT (&idxs)[N_PER_THREAD]) {
    CompareOp op;
#pragma unroll
    for (int i = 0; i < N_PER_THREAD; ++i) {
#pragma unroll
      for (int j = i & 1; j < N_PER_THREAD - 1; j += 2) {
        if (op(vals[j + 1], vals[j])) {
          thread_swap(vals[j + 1], vals[j]);
          if constexpr (ARG_SORT) {
            thread_swap(idxs[j + 1], idxs[j]);
          }
        }
      }
    }
  }
};
template < typename ValT, typename IdxT, bool ARG_SORT, int BLOCK_THREADS, int N_PER_THREAD, typename CompareOp> struct BlockMergeSort { using thread_sort_t = ThreadSort; __device__ __forceinline__ static int merge_partition( const ValT* As, const ValT* Bs, int A_sz, int B_sz, int sort_md) { CompareOp op; int A_st = max(0, sort_md - B_sz); int A_ed = min(sort_md, A_sz); while (A_st < A_ed) { int md = A_st + (A_ed - A_st) / 2; auto a = As[md]; auto b = Bs[sort_md - 1 - md]; if (op(b, a)) { A_ed = md; } else { A_st = md + 1; } } return A_ed; } __device__ __forceinline__ static void merge_step( const ValT* As, const ValT* Bs, const IdxT* As_idx, const IdxT* Bs_idx, int A_sz, int B_sz, ValT (&vals)[N_PER_THREAD], IdxT (&idxs)[N_PER_THREAD]) { CompareOp op; int a_idx = 0; int b_idx = 0; #pragma unroll for (int i = 0; i < N_PER_THREAD; ++i) { auto a = (a_idx < A_sz) ? As[a_idx] : ValT(CompareOp::init()); auto b = (b_idx < B_sz) ? Bs[b_idx] : ValT(CompareOp::init()); bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); vals[i] = pred ? b : a; if constexpr (ARG_SORT) { if (pred) { idxs[i] = Bs_idx[b_idx]; } else { idxs[i] = (a_idx < A_sz) ? 
As_idx[a_idx] : IdxT(0); } } b_idx += int(pred); a_idx += int(!pred); } } __device__ __forceinline__ static void sort(ValT* tgp_vals, IdxT* tgp_idxs, int size_sorted_axis) { int idx = threadIdx.x * N_PER_THREAD; ValT thread_vals[N_PER_THREAD]; IdxT thread_idxs[N_PER_THREAD]; #pragma unroll for (int i = 0; i < N_PER_THREAD; ++i) { thread_vals[i] = tgp_vals[idx + i]; if constexpr (ARG_SORT) { thread_idxs[i] = tgp_idxs[idx + i]; } } if (idx < size_sorted_axis) { thread_sort_t::sort(thread_vals, thread_idxs); } for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) { __syncthreads(); #pragma unroll for (int i = 0; i < N_PER_THREAD; ++i) { tgp_vals[idx + i] = thread_vals[i]; if constexpr (ARG_SORT) { tgp_idxs[idx + i] = thread_idxs[i]; } } __syncthreads(); int merge_group = threadIdx.x / merge_threads; int merge_lane = threadIdx.x % merge_threads; int sort_sz = N_PER_THREAD * merge_threads; int sort_st = N_PER_THREAD * merge_threads * merge_group; int A_st = sort_st; int A_ed = sort_st + sort_sz / 2; int B_st = sort_st + sort_sz / 2; int B_ed = sort_st + sort_sz; const ValT* As = tgp_vals + A_st; const ValT* Bs = tgp_vals + B_st; int A_sz = A_ed - A_st; int B_sz = B_ed - B_st; int sort_md = N_PER_THREAD * merge_lane; int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md); As += partition; Bs += sort_md - partition; A_sz -= partition; B_sz -= sort_md - partition; const IdxT* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr; const IdxT* Bs_idx = ARG_SORT ? 
tgp_idxs + B_st + sort_md - partition : nullptr; merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs); } __syncthreads(); #pragma unroll for (int i = 0; i < N_PER_THREAD; ++i) { tgp_vals[idx + i] = thread_vals[i]; if constexpr (ARG_SORT) { tgp_idxs[idx + i] = thread_idxs[i]; } } } }; template < typename T, typename U, bool ARG_SORT, int BLOCK_THREADS, int N_PER_THREAD, typename CompareOp = LessThan> struct KernelMergeSort { using ValT = T; using IdxT = uint32_t; using block_merge_sort_t = BlockMergeSort< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp>; static constexpr int N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; __device__ __forceinline__ static void block_sort( const T* inp, U* out, int size_sorted_axis, int64_t in_stride_sorted_axis, int64_t out_stride_sorted_axis, int64_t in_stride_segment_axis, int64_t out_stride_segment_axis, ValT* tgp_vals, IdxT* tgp_idxs) { inp += blockIdx.y * in_stride_segment_axis; out += blockIdx.y * out_stride_segment_axis; for (int i = threadIdx.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { tgp_vals[i] = i < size_sorted_axis ? 
inp[i * in_stride_sorted_axis] : ValT(CompareOp::init()); if constexpr (ARG_SORT) { tgp_idxs[i] = i; } } __syncthreads(); block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis); __syncthreads(); for (int i = threadIdx.x; i < size_sorted_axis; i += BLOCK_THREADS) { if constexpr (ARG_SORT) { out[i * out_stride_sorted_axis] = tgp_idxs[i]; } else { out[i * out_stride_sorted_axis] = tgp_vals[i]; } } } }; template < typename T, typename U, bool ARG_SORT, int BLOCK_THREADS, int N_PER_THREAD> __global__ void block_sort_kernel( const T* inp, U* out, int size_sorted_axis, int64_t in_stride_sorted_axis, int64_t out_stride_sorted_axis, int64_t in_stride_segment_axis, int64_t out_stride_segment_axis) { using sort_kernel = KernelMergeSort; using ValT = typename sort_kernel::ValT; using IdxT = typename sort_kernel::IdxT; if constexpr (ARG_SORT) { __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK]; __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( inp, out, size_sorted_axis, in_stride_sorted_axis, out_stride_sorted_axis, in_stride_segment_axis, out_stride_segment_axis, tgp_vals, tgp_idxs); } else { __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( inp, out, size_sorted_axis, in_stride_sorted_axis, out_stride_sorted_axis, in_stride_segment_axis, out_stride_segment_axis, tgp_vals, nullptr); } } template < typename T, typename U, bool ARG_SORT, int BLOCK_THREADS, int N_PER_THREAD> __global__ void block_sort_nc_kernel( const T* inp, U* out, int size_sorted_axis, int64_t in_stride_sorted_axis, int64_t out_stride_sorted_axis, const __grid_constant__ Shape nc_shape, const __grid_constant__ Strides in_nc_strides, const __grid_constant__ Strides out_nc_strides, int nc_dim) { using sort_kernel = KernelMergeSort; using ValT = typename sort_kernel::ValT; using IdxT = typename sort_kernel::IdxT; int64_t in_block_idx = elem_to_loc( int64_t(blockIdx.y), nc_shape.data(), in_nc_strides.data(), nc_dim); int64_t out_block_idx = 
elem_to_loc( int64_t(blockIdx.y), nc_shape.data(), out_nc_strides.data(), nc_dim); inp += in_block_idx; out += out_block_idx; if constexpr (ARG_SORT) { __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK]; __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( inp, out, size_sorted_axis, in_stride_sorted_axis, out_stride_sorted_axis, 0, 0, tgp_vals, tgp_idxs); } else { __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( inp, out, size_sorted_axis, in_stride_sorted_axis, out_stride_sorted_axis, 0, 0, tgp_vals, nullptr); } } template < typename ValT, typename IdxT, bool ARG_SORT, int BLOCK_THREADS, int N_PER_THREAD, typename CompareOp = LessThan> struct KernelMultiBlockMergeSort { using block_merge_sort_t = BlockMergeSort< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp>; static constexpr int N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; __device__ __forceinline__ static void block_sort( const ValT* inp, ValT* out_vals, IdxT* out_idxs, int size_sorted_axis, int64_t stride_sorted_axis, ValT* tgp_vals, IdxT* tgp_idxs) { int base_idx = blockIdx.x * N_PER_BLOCK; for (int i = threadIdx.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { int idx = base_idx + i; tgp_vals[i] = idx < size_sorted_axis ? 
inp[idx * stride_sorted_axis] : ValT(CompareOp::init()); tgp_idxs[i] = idx; } __syncthreads(); block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis); __syncthreads(); for (int i = threadIdx.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { int idx = base_idx + i; if (idx < size_sorted_axis) { out_vals[idx] = tgp_vals[i]; out_idxs[idx] = tgp_idxs[i]; } } } __device__ __forceinline__ static int merge_partition( const ValT* As, const ValT* Bs, int A_sz, int B_sz, int sort_md) { CompareOp op; int A_st = max(0, sort_md - B_sz); int A_ed = min(sort_md, A_sz); while (A_st < A_ed) { int md = A_st + (A_ed - A_st) / 2; auto a = As[md]; auto b = Bs[sort_md - 1 - md]; if (op(b, a)) { A_ed = md; } else { A_st = md + 1; } } return A_ed; } }; template < typename ValT, typename IdxT, bool ARG_SORT, int BLOCK_THREADS, int N_PER_THREAD> __global__ void mb_block_sort_kernel( const ValT* inp, ValT* out_vals, IdxT* out_idxs, int size_sorted_axis, int64_t stride_sorted_axis, const __grid_constant__ Shape nc_shape, const __grid_constant__ Strides nc_strides, int nc_dim) { using sort_kernel = KernelMultiBlockMergeSort< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; int64_t block_idx = elem_to_loc( int64_t(blockIdx.y), nc_shape.data(), nc_strides.data(), nc_dim); inp += block_idx; out_vals += blockIdx.y * size_sorted_axis; out_idxs += blockIdx.y * size_sorted_axis; __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK]; __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( inp, out_vals, out_idxs, size_sorted_axis, stride_sorted_axis, tgp_vals, tgp_idxs); } template < typename ValT, typename IdxT, bool ARG_SORT, int BLOCK_THREADS, int N_PER_THREAD> __global__ void mb_block_partition_kernel( IdxT* block_partitions, const ValT* dev_vals, const IdxT* dev_idxs, int size_sorted_axis, int merge_tiles, int n_blocks) { using sort_kernel = KernelMultiBlockMergeSort< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; (void)dev_idxs; block_partitions += blockIdx.y * 
blockDim.x; dev_vals += blockIdx.y * size_sorted_axis; dev_idxs += blockIdx.y * size_sorted_axis; for (int i = threadIdx.x; i <= n_blocks; i += blockDim.x) { int merge_group = i / merge_tiles; int merge_lane = i % merge_tiles; int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; int A_st = min(size_sorted_axis, sort_st); int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); int B_st = A_ed; int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); int partition = sort_kernel::merge_partition( dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at); block_partitions[i] = A_st + partition; } } template < typename ValT, typename IdxT, bool ARG_SORT, int BLOCK_THREADS, int N_PER_THREAD, typename CompareOp = LessThan> __global__ void mb_block_merge_kernel( const IdxT* block_partitions, const ValT* dev_vals_in, const IdxT* dev_idxs_in, ValT* dev_vals_out, IdxT* dev_idxs_out, int size_sorted_axis, int merge_tiles, int num_tiles) { using sort_kernel = KernelMultiBlockMergeSort< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp>; using block_sort_t = typename sort_kernel::block_merge_sort_t; block_partitions += blockIdx.y * (num_tiles + 1); dev_vals_in += blockIdx.y * size_sorted_axis; dev_idxs_in += blockIdx.y * size_sorted_axis; dev_vals_out += blockIdx.y * size_sorted_axis; dev_idxs_out += blockIdx.y * size_sorted_axis; int block_idx = blockIdx.x; int merge_group = block_idx / merge_tiles; int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st; int A_st = block_partitions[block_idx + 0]; int A_ed = block_partitions[block_idx + 1]; int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st); int B_ed = min( size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md 
+ sort_kernel::N_PER_BLOCK - A_ed); if ((block_idx % merge_tiles) == merge_tiles - 1) { A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); B_ed = min(size_sorted_axis, sort_st + sort_sz); } int A_sz = A_ed - A_st; int B_sz = B_ed - B_st; ValT thread_vals[N_PER_THREAD]; IdxT thread_idxs[N_PER_THREAD]; #pragma unroll for (int i = 0; i < N_PER_THREAD; i++) { int idx = BLOCK_THREADS * i + threadIdx.x; if (idx < (A_sz + B_sz)) { thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz]; thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz]; } else { thread_vals[i] = CompareOp::init(); thread_idxs[i] = 0; } } __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK]; __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; __syncthreads(); #pragma unroll for (int i = 0; i < N_PER_THREAD; i++) { int idx = BLOCK_THREADS * i + threadIdx.x; tgp_vals[idx] = thread_vals[i]; tgp_idxs[idx] = thread_idxs[i]; } __syncthreads(); int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(threadIdx.x)); int A_st_local = block_sort_t::merge_partition( tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local); int A_ed_local = A_sz; int B_st_local = sort_md_local - A_st_local; int B_ed_local = B_sz; int A_sz_local = A_ed_local - A_st_local; int B_sz_local = B_ed_local - B_st_local; block_sort_t::merge_step( tgp_vals + A_st_local, tgp_vals + A_ed_local + B_st_local, tgp_idxs + A_st_local, tgp_idxs + A_ed_local + B_st_local, A_sz_local, B_sz_local, thread_vals, thread_idxs); __syncthreads(); #pragma unroll for (int i = 0; i < N_PER_THREAD; ++i) { int idx = threadIdx.x * N_PER_THREAD; tgp_vals[idx + i] = thread_vals[i]; tgp_idxs[idx + i] = thread_idxs[i]; } __syncthreads(); int base_idx = blockIdx.x * sort_kernel::N_PER_BLOCK; for (int i = threadIdx.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) { int idx = base_idx + i; if (idx < size_sorted_axis) { dev_vals_out[idx] = tgp_vals[i]; dev_idxs_out[idx] = tgp_idxs[i]; } } } } // 
namespace cu namespace { void single_block_sort( const Stream& s, const array& in, array& out, int axis, int bn, bool argsort) { int n_rows = in.size() / in.shape(axis); auto in_nc_str = in.strides(); in_nc_str.erase(in_nc_str.begin() + axis); auto out_nc_str = out.strides(); out_nc_str.erase(out_nc_str.begin() + axis); auto nc_shape = in.shape(); nc_shape.erase(nc_shape.begin() + axis); int nc_dim = nc_shape.size(); int size_sorted_axis = in.shape(axis); int64_t in_stride_sorted_axis = in.strides()[axis]; int64_t out_stride_sorted_axis = out.strides()[axis]; bool contiguous = in.flags().contiguous; auto check_strides = [](const array& x, int64_t sort_stride) { int64_t min_stride = *std::min_element(x.strides().begin(), x.strides().end()); int64_t max_stride = *std::max_element(x.strides().begin(), x.strides().end()); return sort_stride == min_stride || sort_stride == max_stride; }; contiguous &= check_strides(in, in_stride_sorted_axis); contiguous &= check_strides(out, out_stride_sorted_axis); auto& encoder = cu::get_command_encoder(s); out.set_data(cu::malloc_async(out.nbytes(), encoder)); encoder.set_input_array(in); encoder.set_output_array(out); dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag); if constexpr (!std::is_same_v) { using ValT = cuda_type_t; dispatch_block_dim(bn, [&](auto block_dim) { constexpr int BLOCK_THREADS = block_dim(); if constexpr (BLOCK_THREADS < 1024) { dim3 grid(1, n_rows, 1); dim3 block(BLOCK_THREADS, 1, 1); dispatch_bool(argsort, [&](auto arg_tag) { constexpr bool ARG_SORT = decltype(arg_tag)::value; using OutT = std::conditional_t; if (contiguous) { auto kernel = cu::block_sort_kernel< ValT, OutT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; int64_t in_stride_segment_axis = INT64_MAX; int64_t out_stride_segment_axis = INT64_MAX; for (int i = 0; i < nc_shape.size(); i++) { if (nc_shape[i] == 1) { continue; } if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) { throw std::runtime_error( 
"[Sort::eval_gpu] Stride too large."); } in_stride_segment_axis = std::min(in_stride_segment_axis, in_nc_str[i]); out_stride_segment_axis = std::min(out_stride_segment_axis, out_nc_str[i]); } encoder.add_kernel_node( kernel, grid, block, gpu_ptr(in), gpu_ptr(out), size_sorted_axis, in_stride_sorted_axis, out_stride_sorted_axis, in_stride_segment_axis, out_stride_segment_axis); } else { auto kernel = cu::block_sort_nc_kernel< ValT, OutT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; auto nc_shape_param = const_param(nc_shape); auto in_nc_strides_param = const_param(in_nc_str); auto out_nc_strides_param = const_param(out_nc_str); encoder.add_kernel_node( kernel, grid, block, gpu_ptr(in), gpu_ptr(out), size_sorted_axis, in_stride_sorted_axis, out_stride_sorted_axis, nc_shape_param, in_nc_strides_param, out_nc_strides_param, nc_dim); } }); } }); } else { throw std::runtime_error( "CUDA backend does not support sorting complex numbers"); } }); } void multi_block_sort( const Stream& s, const array& in, array& out, int axis, int n_blocks, bool argsort) { int n_rows = in.size() / in.shape(axis); auto nc_str = in.strides(); nc_str.erase(nc_str.begin() + axis); auto nc_shape = in.shape(); nc_shape.erase(nc_shape.begin() + axis); int nc_dim = nc_shape.size(); if (nc_dim == 0) { nc_shape = {0}; nc_str = {1}; } int size_sorted_axis = in.shape(axis); int64_t stride_sorted_axis = in.strides()[axis]; array dev_vals_in({n_rows, size_sorted_axis}, in.dtype(), nullptr, {}); array dev_vals_out({n_rows, size_sorted_axis}, in.dtype(), nullptr, {}); array dev_idxs_in({n_rows, size_sorted_axis}, uint32, nullptr, {}); array dev_idxs_out({n_rows, size_sorted_axis}, uint32, nullptr, {}); array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {}); auto& encoder = cu::get_command_encoder(s); dev_vals_in.set_data(cu::malloc_async(dev_vals_in.nbytes(), encoder)); dev_vals_out.set_data(cu::malloc_async(dev_vals_out.nbytes(), encoder)); 
dev_idxs_in.set_data(cu::malloc_async(dev_idxs_in.nbytes(), encoder)); dev_idxs_out.set_data(cu::malloc_async(dev_idxs_out.nbytes(), encoder)); block_partitions.set_data( cu::malloc_async(block_partitions.nbytes(), encoder)); encoder.add_temporary(block_partitions); dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag); if constexpr (!std::is_same_v) { using ValT = cuda_type_t; using IdxT = uint32_t; constexpr int BLOCK_THREADS = sizeof(ValT) == 8 ? 256 : 512; dim3 grid(n_blocks, n_rows, 1); dim3 block(BLOCK_THREADS, 1, 1); dispatch_bool(argsort, [&](auto arg_tag) { constexpr bool ARG_SORT = decltype(arg_tag)::value; auto nc_shape_param = const_param(nc_shape); auto nc_strides_param = const_param(nc_str); auto block_sort_kernel = cu::mb_block_sort_kernel< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; encoder.set_input_array(in); encoder.set_output_array(dev_vals_in); encoder.set_output_array(dev_idxs_in); encoder.add_kernel_node( block_sort_kernel, grid, block, gpu_ptr(in), gpu_ptr(dev_vals_in), gpu_ptr(dev_idxs_in), size_sorted_axis, stride_sorted_axis, nc_shape_param, nc_strides_param, nc_dim); int n_thr_per_group = (n_blocks + 1) < 1024 ? 
(n_blocks + 1) : 1024; for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) { auto partition_kernel = cu::mb_block_partition_kernel< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; encoder.set_input_array(dev_vals_in); encoder.set_input_array(dev_idxs_in); encoder.set_output_array(block_partitions); encoder.add_kernel_node( partition_kernel, dim3(1, n_rows, 1), dim3(n_thr_per_group, 1, 1), gpu_ptr(block_partitions), gpu_ptr(dev_vals_in), gpu_ptr(dev_idxs_in), size_sorted_axis, merge_tiles, n_blocks); auto merge_kernel = cu::mb_block_merge_kernel< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; encoder.set_input_array(dev_vals_in); encoder.set_input_array(dev_idxs_in); encoder.set_input_array(block_partitions); encoder.set_output_array(dev_vals_out); encoder.set_output_array(dev_idxs_out); encoder.add_kernel_node( merge_kernel, dim3(n_blocks, n_rows, 1), dim3(BLOCK_THREADS, 1, 1), gpu_ptr(block_partitions), gpu_ptr(dev_vals_in), gpu_ptr(dev_idxs_in), gpu_ptr(dev_vals_out), gpu_ptr(dev_idxs_out), size_sorted_axis, merge_tiles, n_blocks); std::swap(dev_vals_in, dev_vals_out); std::swap(dev_idxs_in, dev_idxs_out); } }); } else { throw std::runtime_error( "CUDA backend does not support sorting complex numbers"); } }); encoder.add_temporary(dev_vals_out); encoder.add_temporary(dev_idxs_out); encoder.add_temporary(argsort ? dev_vals_in : dev_idxs_in); if (axis == in.ndim() - 1) { // Copy buffer to out, no need for temporary out.copy_shared_buffer( argsort ? dev_idxs_in : dev_vals_in, out.strides(), out.flags(), out.size()); } else { encoder.add_temporary(argsort ? dev_idxs_in : dev_vals_in); out.set_data(cu::malloc_async(out.nbytes(), encoder)); auto strides = out.strides(); for (int ax = axis + 1; ax < strides.size(); ax++) { strides[ax] *= out.shape(axis); } strides[axis] = 1; copy_gpu_inplace( (argsort) ? 
dev_idxs_in : dev_vals_in, out, out.shape(), strides, out.strides(), 0, 0, CopyType::General, s); } } void gpu_merge_sort( const Stream& s, const array& in, array& out, int axis_, bool argsort) { int axis = axis_ < 0 ? axis_ + in.ndim() : axis_; int size_sorted_axis = in.shape(axis); constexpr int tn = N_PER_THREAD; int potential_bn = (size_sorted_axis + tn - 1) / tn; int bn; if (potential_bn > 256) { bn = 512; } else if (potential_bn > 128) { bn = 256; } else if (potential_bn > 64) { bn = 128; } else if (potential_bn > 32) { bn = 64; } else { bn = 32; } if (bn == 512 && size_of(in.dtype()) > 4) { bn = 256; } int n_per_block = bn * tn; int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block; if (n_blocks > 1) { return multi_block_sort(s, in, out, axis, n_blocks, argsort); } return single_block_sort(s, in, out, axis, bn, argsort); } void gpu_sort( const Stream& s, const array& in, array& out, int axis, bool argsort) { auto& encoder = cu::get_command_encoder(s); gpu_merge_sort(s, in, out, axis, argsort); } } // namespace void ArgSort::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("ArgSort::eval_gpu"); assert(inputs.size() == 1); gpu_sort(stream(), inputs[0], out, axis_, true); } void Sort::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Sort::eval_gpu"); assert(inputs.size() == 1); gpu_sort(stream(), inputs[0], out, axis_, false); } void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("ArgPartition::eval_gpu"); gpu_sort(stream(), inputs[0], out, axis_, true); } void Partition::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Partition::eval_gpu"); gpu_sort(stream(), inputs[0], out, axis_, false); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/steel/defines.cuh ================================================ // Copyright © 2025 Apple Inc. 
#pragma once #define MLX_UNROLL _Pragma("unroll") #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) #define MLX_CUDA_SM_80_ENABLED #endif ================================================ FILE: mlx/backend/cuda/steel/gemm.cuh ================================================ #include "mlx/backend/cuda/steel/mma.cuh" #include "mlx/backend/cuda/steel/tiles.cuh" namespace mlx::core::cu { /** * An example gemm written with the utils. * * Computes A @ B.T when A and B are all aligned with the block sizes. */ template __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { constexpr int WARPS_M = 2; constexpr int WARPS_N = 2; constexpr int NUM_WARPS = WARPS_M * WARPS_N; constexpr int WARP_STEP_M = BM / WARPS_M; constexpr int WARP_STEP_N = BN / WARPS_N; // Precompute some offsets for each thread const int warpid = threadIdx.x / 32; const int laneid = threadIdx.x % 32; const int wm = warpid / WARPS_N; const int wn = warpid % WARPS_N; const int offset_m = wm * WARP_STEP_M; const int offset_n = wn * WARP_STEP_N; // Allocate shared memory extern __shared__ char shmem[]; SharedTile(&as)[2] = *(SharedTile(*)[2])(&shmem[0]); SharedTile(&bs)[2] = *(SharedTile(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]); // Allocate registers for the MMA RegisterTile C; RegisterTile A; RegisterTile B; // Move the global pointers to the tile a += blockIdx.y * BM * K; b += blockIdx.x * BN * K; y += blockIdx.y * BM * N + blockIdx.x * BN; // Zero the accumulators C.fill(0); // Start the SM pipeline load_async(as[0], as[0].base_addr(), a, K); load_async(bs[0], bs[0].base_addr(), b, K); cp_async_commit(); int tic = 0; for (int k_block = BK; k_block < K; k_block += BK) { load_async(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K); load_async(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K); cp_async_commit(); cp_async_wait<1>(); __syncthreads(); MLX_UNROLL for (int k = 0; k < BK / 16; k++) { A.load( as[tic], as[tic].base_addr(), offset_m + laneid % 16, k * 16 + laneid / 16 * 8); 
B.load( bs[tic], bs[tic].base_addr(), offset_n + laneid % 16, k * 16 + laneid / 16 * 8); mma_t(C, A, B); } tic ^= 1; } // Empty the pipeline cp_async_wait_all(); __syncthreads(); MLX_UNROLL for (int k = 0; k < BK / 16; k++) { A.load( as[tic], as[tic].base_addr(), offset_m + laneid % 16, k * 16 + laneid / 16 * 8); B.load( bs[tic], bs[tic].base_addr(), offset_n + laneid % 16, k * 16 + laneid / 16 * 8); mma_t(C, A, B); } C.store_global(y, N, offset_m, offset_n); } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/steel/mma.cuh ================================================ // Copyright © 2025 Apple Inc. #pragma once #include "mlx/backend/cuda/steel/defines.cuh" #include "mlx/backend/cuda/steel/tiles.cuh" namespace mlx::core::cu { /** * Fallback mma. * * We should probably a) implement a fallback or complain about it to the * compiler. */ template __device__ inline void mma_t(Tile16x16& C, Tile16x16& A, Tile16x16& B) {} /** * Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16 * float tile. 
*
 * We actually perform C += A @ B.T. This is the .row.col m16n8k16 mma: A is
 * consumed row-major and B column-major, so the row-major N x K tile B acts
 * as the K x N operand B.T.
 *
 * The 16x16 output is produced by two m16n8k16 instructions:
 *   - The first updates C.values[0..1] using B.values[0] and B.values[2].
 *   - The second updates C.values[2..3] using B.values[1] and B.values[3].
 *
 * C is passed both as the read-write "+f" outputs (%0-%3) and as explicit
 * inputs (%10-%13), which the instruction reads as its accumulator.
 *
 * Without MLX_CUDA_SM_80_ENABLED (device code below sm_80) the body is empty
 * and C is left unchanged.
 */
__device__ __forceinline__ void mma_t(
    Tile16x16& C,
    Tile16x16<__nv_bfloat16>& A,
    Tile16x16<__nv_bfloat16>& B) {
#if defined(MLX_CUDA_SM_80_ENABLED)
  // Left 8 output columns.
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0, %1, %2, %3}, "
      "{%4, %5, %6, %7}, "
      "{%8, %9}, "
      "{%10, %11, %12, %13};"

      // D matrix
      : "+f"(C.values[0].x),
        "+f"(C.values[0].y),
        "+f"(C.values[1].x),
        "+f"(C.values[1].y)

      // A matrix
      : "r"(*(uint32_t*)(&A.values[0])),
        "r"(*(uint32_t*)(&A.values[1])),
        "r"(*(uint32_t*)(&A.values[2])),
        "r"(*(uint32_t*)(&A.values[3])),

        // B matrix
        "r"(*(uint32_t*)(&B.values[0])),
        "r"(*(uint32_t*)(&B.values[2])),

        // C matrix
        "f"(C.values[0].x),
        "f"(C.values[0].y),
        "f"(C.values[1].x),
        "f"(C.values[1].y));

  // Right 8 output columns.
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0, %1, %2, %3}, "
      "{%4, %5, %6, %7}, "
      "{%8, %9}, "
      "{%10, %11, %12, %13};"

      // D matrix
      : "+f"(C.values[2].x),
        "+f"(C.values[2].y),
        "+f"(C.values[3].x),
        "+f"(C.values[3].y)

      // A matrix
      : "r"(*(uint32_t*)(&A.values[0])),
        "r"(*(uint32_t*)(&A.values[1])),
        "r"(*(uint32_t*)(&A.values[2])),
        "r"(*(uint32_t*)(&A.values[3])),

        // B matrix
        "r"(*(uint32_t*)(&B.values[1])),
        "r"(*(uint32_t*)(&B.values[3])),

        // C matrix
        "f"(C.values[2].x),
        "f"(C.values[2].y),
        "f"(C.values[3].x),
        "f"(C.values[3].y));
#endif
}

/**
 * Multiply larger register tiles by delegating to mma_t.
*/
template __device__ __forceinline__ void mma_t(
    RegisterTile& C,
    RegisterTile& A,
    RegisterTile& B) {
  // Tile counts in units of 16x16 basic tiles. The indexing below treats A
  // as TILES_M x TILES_K and B as TILES_N x TILES_K (used transposed). C is
  // TILES_M x TILES_N, and each is stored row-major in `data`.
  constexpr int TILES_M = RegisterTile::TILES_Y;
  constexpr int TILES_K = RegisterTile::TILES_X;
  constexpr int TILES_N = RegisterTile::TILES_Y;

  MLX_UNROLL
  for (int k = 0; k < TILES_K; k++) {
    MLX_UNROLL
    for (int m = 0; m < TILES_M; m++) {
      MLX_UNROLL
      for (int n = 0; n < TILES_N; n++) {
        mma_t(
            C.data[m * TILES_N + n],
            A.data[m * TILES_K + k],
            B.data[n * TILES_K + k]);
      }
    }
  }
}

} // namespace mlx::core::cu
================================================
FILE: mlx/backend/cuda/steel/tiles.cuh
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/steel/utils.cuh"
#include "mlx/backend/cuda/vector_types.cuh"

namespace mlx::core::cu {

/**
 * The basic building block for Ampere mmas. A 16x16 tile distributed across
 * the warp.
 *
 * Each thread holds 8 values, stored as 4 pairs in `values`. They are
 * distributed according to
 * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
 *
 * For use instructions see the individual methods eg load().
 */
template struct Tile16x16 {
  using T2 = Vector2_t;

  T2 values[4];

  // Set all 8 values held by this thread to `v`.
  __device__ inline void fill(T v) {
    T2 v2 = {v, v};
    for (int i = 0; i < 4; i++) {
      values[i] = v2;
    }
  }

  /**
   * Load a 16x16 tile from shared memory.
   *
   * The instruction is a bit weird in the sense that the address provided by
   * each thread and the elements loaded are not the same.
   *
   * We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a
   * result the warp provides 4*8 = 32 addresses one per row.
   *
   * Threads 0-7 provide the addresses for the first tile, 8-15 for the second
   * and so on. For instance to load a non swizzled tile we would do
   *
   *   base_addr + (laneid % 16) * BK + (laneid / 16) * 8
   *
   * Lanes 0-15 therefore point at rows 0-15 of the left 8 columns and lanes
   * 16-31 at rows 0-15 of the right 8 columns. This matches the
   * `laneid / 16 * 8` column offset used by ab_t_aligned.
   *
   * `row_address` is this lane's 32-bit shared-memory address (e.g. from
   * SharedTile::loc). The if constexpr guard restricts the load to specific
   * (presumably 16-bit, per the .b16 qualifier) element types; for other
   * types `values` is left untouched.
   *
   * See
   * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
   */
  __device__ __forceinline__ void load(uint32_t row_address) {
    if constexpr (
        std::is_same_v || std::is_same_v) {
      asm volatile(
          "ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
          : "=r"(*(uint32_t*)&(values[0])),
            "=r"(*(uint32_t*)&(values[1])),
            "=r"(*(uint32_t*)&(values[2])),
            "=r"(*(uint32_t*)&(values[3]))
          : "r"(row_address));
    }
  }

  /**
   * Store the tile to the address pointed to by `x`.
   *
   * The provided pointer is a generic pointer but this is meant to be used to
   * store to global memory. For storing to shared memory we should use
   * `stmatrix`.
   *
   * This also showcases the format of the tile quite nicely. Each register is
   * holding two adjacent values. The indices are
   *
   *   row + 0, col + 0
   *   row + 8, col + 0
   *   row + 0, col + 8
   *   row + 8, col + 8
   *
   * Given that we are dealing with Vector2_t the column offsets are 4
   * instead of 8.
   *
   * `x` is the tile's top-left element in a row-major matrix with leading
   * dimension N (N is assumed even, since rows are addressed in pairs). The
   * second branch converts float pairs to bf16 pairs with round-to-nearest.
   * For type combinations matching neither branch nothing is stored.
   */
  template __device__ inline void store_global(U* x, int N) {
    using U2 = Vector2_t;
    U2* x2 = reinterpret_cast(x);
    const int laneid = threadIdx.x % 32;
    const int row = laneid / 4;
    const int col = laneid % 4;
    if constexpr (std::is_same_v) {
      x2[(row + 0) * (N / 2) + col + 0] = values[0];
      x2[(row + 0) * (N / 2) + col + 4] = values[2];
      x2[(row + 8) * (N / 2) + col + 0] = values[1];
      x2[(row + 8) * (N / 2) + col + 4] = values[3];
    } else if constexpr (
        std::is_same_v && std::is_same_v) {
      x2[(row + 0) * (N / 2) + col + 0] =
          __floats2bfloat162_rn(values[0].x, values[0].y);
      x2[(row + 0) * (N / 2) + col + 4] =
          __floats2bfloat162_rn(values[2].x, values[2].y);
      x2[(row + 8) * (N / 2) + col + 0] =
          __floats2bfloat162_rn(values[1].x, values[1].y);
      x2[(row + 8) * (N / 2) + col + 4] =
          __floats2bfloat162_rn(values[3].x, values[3].y);
    }
  }

  // Element-wise variant of store_global: converts each value with
  // static_cast and skips tile rows whose index is >= max_rows (the upper
  // 8 rows and lower 8 rows are checked independently).
  template __device__ inline void store_global_safe(U* x, int N, int max_rows) {
    const int laneid = threadIdx.x % 32;
    const int row = laneid / 4;
    const int col = laneid % 4;
    if (row < max_rows) {
      x[(row + 0) * N + 2 * col + 0] = static_cast(values[0].x);
      x[(row + 0) * N + 2 * col + 1] = static_cast(values[0].y);
      x[(row + 0) * N + 2 * col + 8] = static_cast(values[2].x);
      x[(row + 0) * N + 2 * col + 9] = static_cast(values[2].y);
    }
    if (row + 8 < max_rows) {
      x[(row + 8) * N + 2 * col + 0] = static_cast(values[1].x);
      x[(row + 8) * N + 2 * col + 1] = static_cast(values[1].y);
      x[(row + 8) * N + 2 * col + 8] = static_cast(values[3].x);
      x[(row + 8) * N + 2 * col + 9] = static_cast(values[3].y);
    }
  }
};

/**
 * A simple container of multiple Tile16x16.
 *
 * Provides utility functions for loading and manipulating collections of basic
 * tiles.
*/
template struct RegisterTile {
  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  // Number of 16x16 basic tiles along the columns (X) and rows (Y).
  static constexpr int TILES_X = COLS / 16;
  static constexpr int TILES_Y = ROWS / 16;

  // Basic tiles stored row-major: tile (i, j) is data[i * TILES_X + j].
  Tile16x16 data[TILES_X * TILES_Y];

  // Fill every basic tile with `v`.
  __device__ inline void fill(T v) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].fill(v);
      }
    }
  }

  // Load every basic tile from the shared tile `tile` with Tile16x16::load.
  // (row, col) is this lane's address into the top-left basic tile and is
  // advanced by 16 rows/columns per basic tile.
  template __device__ __forceinline__ void
  load(Tile& tile, uint32_t base_address, int row, int col) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].load(
            tile.loc(base_address, row + i * 16, col + j * 16));
      }
    }
  }

  // Same traversal as above, but each basic-tile load is delegated to
  // f(basic_tile, tile, base_address, row, col).
  template __device__ __forceinline__ void
  load(Tile& tile, F f, uint32_t base_address, int row, int col) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        f(data[i * TILES_X + j], tile, base_address, row + i * 16, col + j * 16);
      }
    }
  }

  // Store to a row-major matrix `x` with leading dimension N, placing basic
  // tile (i, j) at (row + 16 * i, col + 16 * j).
  template __device__ inline void store_global(U* x, int N, int row, int col) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].store_global(
            x + (row + i * 16) * N + col + j * 16, N);
      }
    }
  }

  // Bounds-checked store_global: matrix rows with index >= max_rows are
  // skipped (each basic tile receives the remaining row budget).
  template __device__ inline void
  store_global_safe(U* x, int N, int row, int col, int max_rows) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].store_global_safe(
            x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
      }
    }
  }
};

/**
 * A simple container of multiple Tile16x16.
 *
 * Provides utility functions for loading and manipulating collections of basic
 * tiles.
*/
template struct RegisterTile {
  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  // Number of 16x16 basic tiles along the columns (X) and rows (Y).
  static constexpr int TILES_X = COLS / 16;
  static constexpr int TILES_Y = ROWS / 16;

  // Basic tiles stored row-major: tile (i, j) is data[i * TILES_X + j].
  Tile16x16 data[TILES_X * TILES_Y];

  // Fill every basic tile with `v`.
  __device__ inline void fill(T v) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].fill(v);
      }
    }
  }

  // Load every basic tile from `tile`, advancing this lane's (row, col) by
  // 16 per basic tile.
  template __device__ inline void
  load(Tile& tile, uint32_t base_address, int row, int col) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].load(
            tile.loc(base_address, row + i * 16, col + j * 16));
      }
    }
  }

  // Store to a row-major matrix `x` with leading dimension N, placing basic
  // tile (i, j) at (row + 16 * i, col + 16 * j).
  template __device__ inline void store_global(U* x, int N, int row, int col) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].store_global(
            x + (row + i * 16) * N + col + j * 16, N);
      }
    }
  }
};

// A ROWS x COLS tile of T held in shared memory with an (optional) XOR
// swizzle, addressed either through generic pointers (ptr) or 32-bit
// shared-memory addresses (loc) for ldmatrix/cp.async.
template struct SharedTile {
  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  static constexpr int TILES_X = COLS / 16;
  static constexpr int TILES_Y = ROWS / 16;
  static constexpr int NUMEL = ROWS * COLS;

  // Swizzle taken from ThunderKittens. Should be changed when we switch to
  // cute Layouts.
  //
  // See include/types/shared/st.cuh in ThunderKittens.
  //
  // I do feel that it is too math heavy and can be improved. Also the math is
  // done every time although the addresses don't change from load to load. I
  // guess we are expecting the compiler to figure that out.
  //
  // swizzle_bytes is the width in bytes of one column subtile. For 2-byte
  // types it is 128, 64 or 32, and for 4-byte types 128 or 64, depending on
  // how TILES_X divides. It is 0 (no swizzle) for other element sizes.
  static constexpr int swizzle_bytes =
      (sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
                      : (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));

  T data[ROWS * COLS];

  // 32-bit shared-memory address of data[0].
  __device__ inline uint32_t base_addr() const {
    return __cvta_generic_to_shared(&data[0]);
  }

  // Return a pointer to the element at (row, col) using the swizzle.
  //
  // Columns are grouped into subtiles of subtile_cols columns. Each subtile
  // is stored row-major, one after the other. The address bits from bit 7
  // upward (within each swizzle_repeat-byte window) are XORed into bits 4
  // and up, i.e. at 16-byte granularity.
  __device__ static inline T* ptr(T* ptr, int row, int col) {
    if constexpr (swizzle_bytes > 0) {
      static constexpr int swizzle_repeat = swizzle_bytes * 8;
      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
      const int outer_idx = col / subtile_cols;
      const uint64_t addr =
          (uint64_t)(&ptr
                         [outer_idx * ROWS * subtile_cols + row * subtile_cols +
                          col % subtile_cols]);
      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
      return (T*)(addr ^ swizzle);
    } else {
      return ptr + row * COLS + col;
    }
  }

  // Return the location of the element at (row, col) using the swizzle.
  // Same layout as ptr() but on a 32-bit shared-memory address.
  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
    if constexpr (swizzle_bytes > 0) {
      static constexpr int swizzle_repeat = swizzle_bytes * 8;
      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
      const int outer_idx = col / subtile_cols;
      const uint32_t addr = ptr +
          sizeof(T) *
              (outer_idx * ROWS * subtile_cols + row * subtile_cols +
               col % subtile_cols);
      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
      return (addr ^ swizzle);
    } else {
      return ptr + sizeof(T) * (row * COLS + col);
    }
  }

  // Convenience functions to edit elements going through the swizzle.
  __device__ inline T& operator()(int row, int col) {
    return *ptr(data, row, col);
  }

  // Store 16, 8 or 4 bytes starting at element (row, col). The bytes must
  // not cross a swizzle chunk; this is assumed, not checked.
  __device__ inline void store(float4& v, int row, int col) {
    *(reinterpret_cast(ptr(data, row, col))) = v;
  }

  __device__ inline void store(float2& v, int row, int col) {
    *(reinterpret_cast(ptr(data, row, col))) = v;
  }

  __device__ inline void store(float& v, int row, int col) {
    *(reinterpret_cast(ptr(data, row, col))) = v;
  }

  // Store N consecutive elements starting at (row, col), using one vector
  // store when the total size is 4, 8 or 16 bytes and element-wise stores
  // otherwise.
  template __device__ inline void store(T (&v)[N], int row, int col) {
    if constexpr (sizeof(T) * N == 4) {
      store(*(reinterpret_cast(&v[0])), row, col);
    } else if constexpr (sizeof(T) * N == 8) {
      store(*(reinterpret_cast(&v[0])), row, col);
    } else if constexpr (sizeof(T) * N == 16) {
      store(*(reinterpret_cast(&v[0])), row, col);
    } else {
      MLX_UNROLL
      for (int i = 0; i < N; i++) {
        *ptr(data, row, col + i) = v[i];
      }
    }
  }
};

/**
 * Load the tile from global memory by loading 16 bytes at a time and storing
 * them immediately.
 *
 * Can also be used as a fallback for architectures before sm_80.
 *
 * `x` points at the tile's top-left element in a row-major matrix with
 * leading dimension N. Each of the NUM_WARPS * 32 threads copies
 * NUM_LOADS_PER_THREAD float4 chunks. This assumes the tile divides evenly
 * among the threads and that every row start is 16-byte aligned.
 */
template __device__ inline void load(Tile& tile, const T* x, int N) {
  constexpr int NUM_THREADS = NUM_WARPS * 32;
  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;

  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
  const int col = threadIdx.x % NUM_LOADS_PER_ROW;

  x += row * N + col * ELEMENTS_PER_LOAD;

  MLX_UNROLL
  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
    float4 tmp;
    tmp = *(reinterpret_cast(&x[i * STEP_ROWS * N]));
    tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
  }
}

/**
 * The asynchronous equivalent of load.
 *
 * Loads the tile from global memory by submitting a bunch of async copy
 * instructions. The copies are issued immediately, but they can only be waited
 * on once they have been grouped with cp_async_commit(), and there is no
 * guarantee they have finished until a matching cp_async_wait returns.
*
 * It should be used as follows
 *
 *   load_async(...)
 *   load_async(...)
 *   cp_async_commit()
 *   do_other_stuff()
 *   cp_async_wait_all()
 *   __syncthreads()  // the wait only covers this thread's own copies
 *   do_stuff_with_shmem()
 *
 * base_address is the shared-memory address of the tile (see
 * SharedTile::base_addr). x points at the tile's top-left element in a
 * row-major matrix with leading dimension N.
 */
template __device__ inline void
load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
  constexpr int NUM_THREADS = NUM_WARPS * 32;
  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;

  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
  const int col = threadIdx.x % NUM_LOADS_PER_ROW;

  x += row * N + col * ELEMENTS_PER_LOAD;

  // One 16-byte cp.async per chunk, targeting the swizzled location.
  MLX_UNROLL
  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
    cp_async<16>(
        tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
        x + i * STEP_ROWS * N);
  }
}

/**
 * Same as load_async but checks if we can load the row.
 *
 * NOTE: It should be changed to use a predicated cp async instead.
*/
template __device__ inline void load_async_safe(
    Tile& tile,
    uint32_t base_address,
    const T* x,
    int N,
    int max_rows) {
  constexpr int NUM_THREADS = NUM_WARPS * 32;
  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;

  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
  const int col = threadIdx.x % NUM_LOADS_PER_ROW;

  x += row * N + col * ELEMENTS_PER_LOAD;

  MLX_UNROLL
  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
    if (row + i * STEP_ROWS < max_rows) {
      cp_async<16>(
          tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
          x + i * STEP_ROWS * N);
    } else {
      // Rows at or beyond max_rows are zero-filled with a regular
      // shared-memory store instead of being read from x.
      float4 tmp = {0, 0, 0, 0};
      tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
    }
  }
}

} // namespace mlx::core::cu
================================================
FILE: mlx/backend/cuda/steel/utils.cuh
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/steel/defines.cuh"

namespace mlx::core::cu {

/**
 * Copy bytes from the global memory address pointed to by x to the smem
 * address pointed to by row_address.
 *
 * A simple wrapper over the PTX.
*/
template __device__ inline void cp_async(uint32_t row_address, const T* x) {
  // N is the copy size in bytes.
  static_assert(
      N == 16 || N == 8 || N == 4,
      "cp.async is only supported for N in {4, 8, 16}.");
  // NOTE(review): without MLX_CUDA_SM_80_ENABLED this is a no-op and the
  // shared-memory destination is never written.
#if defined(MLX_CUDA_SM_80_ENABLED)
  if constexpr (N == 16) {
    asm volatile(
        "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
        "l"(reinterpret_cast(x)));
  } else if constexpr (N == 8) {
    asm volatile(
        "cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
        "l"(reinterpret_cast(x)));
  } else if constexpr (N == 4) {
    asm volatile(
        "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
        "l"(reinterpret_cast(x)));
  }
#endif
}

/**
 * Close the current cp.async group: all copies this thread has issued and
 * not yet committed become one group that cp_async_wait can wait on.
 */
__device__ inline void cp_async_commit() {
#if defined(MLX_CUDA_SM_80_ENABLED)
  asm volatile("cp.async.commit_group;\n" ::);
#endif
}

/**
 * Wait until at most N of this thread's most recently committed cp.async
 * groups are still pending. N == 0 uses cp.async.wait_all, which waits for
 * every outstanding copy. Only the calling thread's copies are covered; use
 * __syncthreads() before reading data copied by other threads.
 */
template __device__ inline void cp_async_wait() {
#if defined(MLX_CUDA_SM_80_ENABLED)
  if constexpr (N == 0) {
    asm volatile("cp.async.wait_all;\n" ::);
  } else {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
  }
#endif
}

/**
 * Wait for all the async copies to finish.
 */
__device__ inline void cp_async_wait_all() {
  cp_async_wait<0>();
}

/**
 * Extract ``bits`` bits from the 32 bit value.
 *
 * Single instruction shift and mask.
*/ template __device__ inline uint32_t extract_bits(uint32_t value, int start_bit) { static_assert( bits == 2 || bits == 4 || bits == 8, "extract_bits only supports 2, 4, 8 for now."); uint32_t result; if constexpr (bits == 2) { asm("bfe.u32 %0, %1, %2, 2;" : "=r"(result) : "r"(value), "r"(start_bit)); } else if constexpr (bits == 4) { asm("bfe.u32 %0, %1, %2, 4;" : "=r"(result) : "r"(value), "r"(start_bit)); } else if constexpr (bits == 8) { asm("bfe.u32 %0, %1, %2, 8;" : "=r"(result) : "r"(value), "r"(start_bit)); } return result; } } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/ternary.cu ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/common/ternary.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/ternary_ops.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" #include #include namespace mlx::core { namespace cu { namespace cg = cooperative_groups; template __global__ void ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); if ((index + 1) * N_READS > size) { for (IdxT i = index * N_READS; i < size; ++i) { out[i] = Op{}(a[i], b[i], c[i]); } } else { auto a_vec = load_vector(a, index); auto b_vec = load_vector(b, index); auto c_vec = load_vector(c, index); AlignedVector out_vec; #pragma unroll for (int i = 0; i < N_READS; ++i) { out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]); } store_vector(out, index, out_vec); } } template __global__ void ternary_g_nd( const bool* a, const T* b, const T* c, T* out, IdxT size_rest, const __grid_constant__ cuda::std::array shape, const __grid_constant__ cuda::std::array a_strides, const __grid_constant__ cuda::std::array b_strides, const __grid_constant__ cuda::std::array c_strides) { auto block = cg::this_thread_block(); auto grid = cg::this_grid(); IdxT index_rest = 
grid.block_index().y * block.dim_threads().y + block.thread_index().y; if (index_rest >= size_rest) { return; } auto shape_x = shape[NDIM - 1]; auto a_stride_x = a_strides[NDIM - 1]; auto b_stride_x = b_strides[NDIM - 1]; auto c_stride_x = c_strides[NDIM - 1]; IdxT index_x = grid.block_index().x * block.dim_threads().x + block.thread_index().x; auto [a_idx, b_idx, c_idx] = elem_to_loc_nd( index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data(), c_strides.data()); auto a_vec = load_vector(a + a_idx, index_x, shape_x, a_stride_x, false); auto b_vec = load_vector(b + b_idx, index_x, shape_x, b_stride_x, T(0)); auto c_vec = load_vector(c + c_idx, index_x, shape_x, c_stride_x, T(0)); AlignedVector out_vec; #pragma unroll for (int i = 0; i < N_READS; ++i) { out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]); } store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } template __global__ void ternary_g( const bool* a, const T* b, const T* c, T* out, IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides a_strides, const __grid_constant__ Strides b_strides, const __grid_constant__ Strides c_strides, int ndim) { auto block = cg::this_thread_block(); auto grid = cg::this_grid(); IdxT index_rest = grid.block_index().y * block.dim_threads().y + block.thread_index().y; if (index_rest >= size_rest) { return; } auto shape_x = shape[ndim - 1]; auto a_stride_x = a_strides[ndim - 1]; auto b_stride_x = b_strides[ndim - 1]; auto c_stride_x = c_strides[ndim - 1]; IdxT index_x = grid.block_index().x * block.dim_threads().x + block.thread_index().x; auto [a_idx, b_idx, c_idx] = elem_to_loc( index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data(), c_strides.data(), ndim); auto a_vec = load_vector(a + a_idx, index_x, shape_x, a_stride_x, false); auto b_vec = load_vector(b + b_idx, index_x, shape_x, b_stride_x, T(0)); auto c_vec = load_vector(c + c_idx, index_x, shape_x, c_stride_x, T(0)); AlignedVector out_vec; #pragma 
unroll for (int i = 0; i < N_READS; ++i) { out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]); } store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } } // namespace cu template void ternary_op_gpu_inplace( const std::vector& inputs, array& out, const Stream& s) { const auto& a = inputs[0]; const auto& b = inputs[1]; const auto& c = inputs[2]; if (out.size() == 0) { return; } auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(c); encoder.set_output_array(out); dispatch_all_types(out.dtype(), [&](auto type_tag) { using DType = cuda_type_t; auto topt = get_ternary_op_type(a, b, c); if (topt == TernaryOpType::VectorVectorVector || topt == TernaryOpType::ScalarScalarScalar) { dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { using IdxT = std::conditional_t; constexpr int N_READS = 16 / sizeof(DType); auto [num_blocks, block_dims] = get_launch_args( out.data_size(), out.shape(), out.strides(), large(), N_READS); encoder.add_kernel_node( cu::ternary_v, num_blocks, block_dims, gpu_ptr(a), gpu_ptr(b), gpu_ptr(c), gpu_ptr(out), out.data_size()); }); } else { dispatch_bool( a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, [&](auto large) { using IdxT = std::conditional_t; Shape shape; std::vector strides; std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); auto& a_strides = strides[0]; auto& b_strides = strides[1]; auto& c_strides = strides[2]; int ndim = shape.size(); int work_per_thread = 1; auto dim0 = ndim > 0 ? 
shape.back() : 1; auto rest = out.size() / dim0; if (dim0 >= 4) { work_per_thread = 4; } dim0 = (dim0 + work_per_thread - 1) / work_per_thread; auto block_dims = get_block_dims(dim0, rest, 1); uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { auto kernel = cu::ternary_g_nd; if (work_per_thread == 4) { kernel = cu::ternary_g_nd; } encoder.add_kernel_node( kernel, {num_blocks_x, num_blocks_y}, block_dims, gpu_ptr(a), gpu_ptr(b), gpu_ptr(c), gpu_ptr(out), rest, const_param(shape), const_param(a_strides), const_param(b_strides), const_param(c_strides)); }); } else { auto kernel = cu::ternary_g; if (work_per_thread == 4) { kernel = cu::ternary_g; } encoder.add_kernel_node( kernel, {num_blocks_x, num_blocks_y}, block_dims, gpu_ptr(a), gpu_ptr(b), gpu_ptr(c), gpu_ptr(out), rest, const_param(shape), const_param(a_strides), const_param(b_strides), const_param(c_strides), ndim); } }); } }); } template void ternary_op_gpu( const std::vector& inputs, array& out, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; auto& c = inputs[2]; auto topt = get_ternary_op_type(a, b, c); auto& encoder = cu::get_command_encoder(s); set_ternary_op_output_data( a, b, c, out, topt, [&](auto n) { return cu::malloc_async(n, encoder); }); ternary_op_gpu_inplace(inputs, out, s); } void Select::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Select::eval_gpu"); auto& s = out.primitive().stream(); ternary_op_gpu(inputs, out, s); } } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/unary/CMakeLists.txt ================================================ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/abs.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccos.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccosh.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsin.cu PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR}/arcsinh.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctanh.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_invert.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ceil.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conjugate.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cos.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cosh.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf_inv.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/exp.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/expm1.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/floor.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/imag.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log1p.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_not.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/negative.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/real.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/round.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sigmoid.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sign.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sin.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sinh.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sqrt.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/square.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tan.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tanh.cu) ================================================ FILE: mlx/backend/cuda/unary/abs.cu ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/unary/unary.cuh" namespace mlx::core { UNARY_GPU(Abs) } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/unary/arccos.cu ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/unary/unary.cuh" namespace mlx::core { UNARY_GPU(ArcCos) } // namespace mlx::core ================================================ FILE: mlx/backend/cuda/unary/arccosh.cu ================================================ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines ArcCosh::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(ArcCosh)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/arcsin.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines ArcSin::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(ArcSin)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/arcsinh.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines ArcSinh::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(ArcSinh)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/arctan.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines ArcTan::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(ArcTan)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/arctanh.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines ArcTanh::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(ArcTanh)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/bitwise_invert.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines BitwiseInvert::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(BitwiseInvert)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/ceil.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Ceil::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Ceil)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/conjugate.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Conjugate::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Conjugate)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/cos.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Cos::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Cos)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/cosh.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Cosh::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Cosh)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/erf.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Erf::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Erf)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/erf_inv.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines ErfInv::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(ErfInv)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/exp.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Exp::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Exp)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/expm1.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Expm1::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Expm1)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/floor.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Floor::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Floor)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/imag.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Imag::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Imag)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/log.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Log is written out by hand instead of using UNARY_GPU because the kernel
// depends on base_.
void Log::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("Log::eval_gpu");
  auto& s = out.primitive().stream();
  // Each case instantiates unary_op_gpu with the op functor for that base
  // (presumably the natural, base-2 and base-10 log ops — confirm).
  switch (base_) {
    case Base::e:
      unary_op_gpu(inputs, out, name(), s);
      break;
    case Base::two:
      unary_op_gpu(inputs, out, name(), s);
      break;
    case Base::ten:
      unary_op_gpu(inputs, out, name(), s);
      break;
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/log1p.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Log1p::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Log1p)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/logical_not.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines LogicalNot::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(LogicalNot)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/negative.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Negative::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Negative)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/real.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Real::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Real)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/round.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Round launches a kernel only for inexact dtypes. Any other input is
// returned unchanged by sharing its buffer with `out`.
void Round::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("Round::eval_gpu");
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  auto& s = out.primitive().stream();
  if (issubdtype(in.dtype(), inexact)) {
    unary_op_gpu(inputs, out, name(), s);
  } else {
    // No-op integer types
    out.copy_shared_buffer(in);
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/sigmoid.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Sigmoid::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Sigmoid)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/sign.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Sign::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Sign)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/sin.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Sin::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Sin)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/sinh.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Sinh::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Sinh)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/sqrt.cu
================================================
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// recip_ selects the reciprocal square root. The op name forwarded to
// unary_op_gpu (used only in its unsupported-dtype error message) is the
// literal "Rsqrt" / "Sqrt" rather than name().
void Sqrt::eval_gpu(const std::vector& inputs, array& out) {
  nvtx3::scoped_range r("Sqrt::eval_gpu");
  auto& s = out.primitive().stream();
  if (recip_) {
    unary_op_gpu(inputs, out, "Rsqrt", s);
  } else {
    unary_op_gpu(inputs, out, "Sqrt", s);
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/square.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Square::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Square)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/tan.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Tan::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Tan)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/tanh.cu
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {

// Defines Tanh::eval_gpu (UNARY_GPU is defined in unary.cuh).
UNARY_GPU(Tanh)

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/unary/unary.cuh
================================================
// Copyright © 2025 Apple Inc.
// unary.cuh is included by every unary/*.cu file and defines templates and a
// macro; guard against double inclusion.
#pragma once

#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include
#include

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Contiguous kernel: out[i] = Op{}(in[i]) over `size` elements. Each thread
// owns N_READS consecutive elements and uses vector loads/stores; the thread
// whose chunk would run past `size` falls back to a scalar loop.
template __global__ void unary_v(const In* in, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = Op{}(in[i]);
    }
  } else {
    auto in_vec = load_vector(in, index);

    AlignedVector out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = Op{}(in_vec[i]);
    }

    store_vector(out, index, out_vec);
  }
}

// Strided kernel: the y dimension of the launch selects a "row" (all dims but
// the last; size_rest rows) and x selects a chunk along the last dimension.
// Input is read through its strides; output is written densely.
template __global__ void unary_g(
    const In* in,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides,
    int ndim) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[ndim - 1];
  auto stride_x = strides[ndim - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto idx =
      elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
  auto in_vec =
      load_vector(in + idx, index_x, shape_x, stride_x, In(0));
  AlignedVector out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = Op{}(in_vec[i]);
  }
  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}

// Compile-time whitelist of (Op, In, Out) combinations. Each group of ops
// accepts one input/output dtype relationship; anything not listed returns
// false, and unary_op_gpu_inplace throws for it.
template constexpr bool supports_unary_op() {
  if (std::is_same_v || std::is_same_v || std::is_same_v ||
      std::is_same_v) {
    return std::is_same_v;
  }
  if (std::is_same_v || std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v || std::is_same_v ||
      std::is_same_v) {
    return std::is_same_v && is_floating_v;
  }
  if (std::is_same_v) {
    return std::is_same_v && std::is_integral_v &&
        !std::is_same_v;
  }
  if (std::is_same_v || std::is_same_v) {
    return std::is_same_v && !mlx::core::is_complex_v;
  }
  if (std::is_same_v) {
    return std::is_same_v && mlx::core::is_complex_v;
  }
  if (std::is_same_v || std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v) {
    return std::is_same_v && is_inexact_v;
  }
  if (std::is_same_v || std::is_same_v) {
    return mlx::core::is_complex_v && std::is_same_v;
  }
  if (std::is_same_v) {
    return std::is_same_v && std::is_same_v;
  }
  if (std::is_same_v) {
    return std::is_same_v && is_floating_v;
  }
  if (std::is_same_v) {
    return std::is_same_v && is_floating_v;
  }
  return false;
}

} // namespace cu

// Launch the unary op into an already-allocated `out`.
// - Contiguous input: flat unary_v; the index type depends on
//   in.data_size() > UINT32_MAX.
// - Otherwise: collapse contiguous dims and launch unary_g on a 2D grid; the
//   index type depends on any size > INT32_MAX.
// `op` is used only in the error message for unsupported dtype pairs.
template void unary_op_gpu_inplace(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  auto& in = inputs[0];
  if (in.size() == 0) {
    return;
  }
  bool contig = in.flags().contiguous;
  bool large;
  if (!contig) {
    large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
  } else {
    large = in.data_size() > UINT32_MAX;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
      if constexpr (cu::supports_unary_op()) {
        dispatch_bool(large, [&](auto large) {
          using InType = cuda_type_t;
          using OutType = cuda_type_t;
          if (contig) {
            using IdxT = std::conditional_t;
            // One 16-byte output vector per thread.
            constexpr int N_READS = 16 / sizeof(OutType);
            auto [num_blocks, block_dims] = get_launch_args(
                out.data_size(), out.shape(), out.strides(), large, N_READS);
            encoder.add_kernel_node(
                cu::unary_v,
                num_blocks,
                block_dims,
                gpu_ptr(in),
                gpu_ptr(out),
                out.data_size());
          } else {
            using IdxT = std::conditional_t;
            auto [shape, strides] = collapse_contiguous_dims(in);
            auto ndim = shape.size();
            // 4 elements of the last dim per thread when it has at least 4.
            int work_per_thread = 1;
            auto kernel = cu::unary_g;
            auto dim0 = ndim > 0 ? shape.back() : 1;
            auto rest = out.size() / dim0;
            if (dim0 >= 4) {
              kernel = cu::unary_g;
              work_per_thread = 4;
            }
            dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
            auto block_dims = get_block_dims(dim0, rest, 1);
            uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
            uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
            encoder.add_kernel_node(
                kernel,
                {num_blocks_x, num_blocks_y},
                block_dims,
                gpu_ptr(in),
                gpu_ptr(out),
                rest,
                const_param(shape),
                const_param(strides),
                ndim);
          }
        });
      } else {
        throw std::runtime_error(
            fmt::format(
                "Can not do unary op {} on input of {} with output of {}.",
                op,
                dtype_to_string(in.dtype()),
                dtype_to_string(out.dtype())));
      }
    });
  });
}

// Allocate `out` (set_unary_output_data with cu::malloc_async on this
// stream's encoder), then run the kernel into it.
template void unary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  auto& encoder = cu::get_command_encoder(s);
  set_unary_output_data(
      inputs[0], out, [&](auto n) { return cu::malloc_async(n, encoder); });
  unary_op_gpu_inplace(inputs, out, op, s);
}

// Defines func::eval_gpu as a plain unary_op_gpu call on the primitive's
// stream, with an NVTX range named "<func>::eval_gpu".
#define UNARY_GPU(func)                                               \
  void func::eval_gpu(const std::vector& inputs, array& out) { \
    nvtx3::scoped_range r(#func "::eval_gpu");                       \
    auto& s = out.primitive().stream();                               \
    unary_op_gpu(inputs, out, name(), s);                       \
  }

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/utils.cpp
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/dtype_utils.h"

#include
#include
#include

namespace mlx::core {

void check_cublas_error(const char* name, cublasStatus_t err) {
  if (err != CUBLAS_STATUS_SUCCESS) {
    // TODO: Use cublasGetStatusString when it is widely available.
    throw std::runtime_error(
        fmt::format("{} failed with code: {}.", name, static_cast(err)));
  }
}

void check_cuda_error(const char* name, cudaError_t err) {
  if (err != cudaSuccess) {
    throw std::runtime_error(
        fmt::format("{} failed: {}", name, cudaGetErrorString(err)));
  }
}

// Driver-API variant. For codes it does not recognize, cuGetErrorString()
// returns an error and sets its output to NULL. In that case we fall back to
// "Unknown error" instead of formatting a null C string.
void check_cuda_error(const char* name, CUresult err) {
  if (err != CUDA_SUCCESS) {
    const char* err_str = nullptr;
    if (cuGetErrorString(err, &err_str) != CUDA_SUCCESS ||
        err_str == nullptr) {
      err_str = "Unknown error";
    }
    throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
  }
}

void check_cudnn_error(const char* name, cudnnStatus_t err) {
  if (err != CUDNN_STATUS_SUCCESS) {
    throw std::runtime_error(
        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
  }
}

// Name of the CUDA C++ type for a Dtype, or "unknown" for unmapped dtypes.
const char* dtype_to_cuda_type(const Dtype& dtype) {
  switch (dtype) {
    case bool_:
      return "bool";
    case int8:
      return "int8_t";
    case int16:
      return "int16_t";
    case int32:
      return "int32_t";
    case int64:
      return "int64_t";
    case uint8:
      return "uint8_t";
    case uint16:
      return "uint16_t";
    case uint32:
      return "uint32_t";
    case uint64:
      return "uint64_t";
    case float16:
      return "__half";
    case bfloat16:
      return "__nv_bfloat16";
    case float32:
      return "float";
    case float64:
      return "double";
    case complex64:
      return "mlx::core::cu::complex64_t";
    default:
      return "unknown";
  }
}

CudaGraph::CudaGraph(cu::Device& device) {
  device.make_current();
  CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0));
}

// Replace the graph owned by this object with the one captured on |stream|.
// cudaStreamEndCapture only writes a new graph handle. The graph this object
// previously owned (e.g. the one created in the constructor) must therefore
// be destroyed here, or it leaks. If the capture fails, the old graph stays
// owned and untouched. The old graph is not destroyed when the capture
// returns that same handle.
void CudaGraph::end_capture(cudaStream_t stream) {
  cudaGraph_t captured = nullptr;
  CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &captured));
  cudaGraph_t previous = handle_;
  handle_ = captured;
  if (previous != nullptr && previous != captured) {
    CHECK_CUDA_ERROR(cudaGraphDestroy(previous));
  }
}

void CudaGraphExec::instantiate(cudaGraph_t graph) {
  assert(handle_ == nullptr);
  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
}

CudaStream::CudaStream(cu::Device& device) {
  device.make_current();
  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking));
}

// Returns nullptr for a zero-byte request. Otherwise returns a device buffer
// of workspace_size rounded up to a multiple of 256 bytes, kept alive as a
// temporary of `encoder`.
void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size) {
  if (workspace_size == 0) {
    return nullptr;
  }

  // Workspace allocation should not be captured.
#ifndef NDEBUG
  cudaStreamCaptureStatus status;
  CHECK_CUDA_ERROR(cudaStreamIsCapturing(encoder.stream(), &status));
  assert(status == cudaStreamCaptureStatusNone);
#endif

  // Ensure workspace is 256-byte aligned. Round in size_t, then check that
  // the size fits the int used for the backing array's shape; a silent
  // narrowing would give a wrong (possibly negative) size.
  size_t nbytes = cuda::ceil_div(workspace_size, 256) * 256;
  if (nbytes > static_cast<size_t>(INT32_MAX)) {
    throw std::runtime_error(fmt::format(
        "[allocate_workspace] Requested workspace of {} bytes exceeds the "
        "maximum supported size of {} bytes.",
        workspace_size,
        INT32_MAX));
  }
  int shape_nbytes = static_cast<int>(nbytes);
  array workspace(cu::malloc_async(nbytes, encoder), {shape_nbytes}, int8);
  encoder.add_temporary(workspace);
  return gpu_ptr(workspace);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/utils.h
================================================
// Copyright © 2025 Apple Inc.

// This file includes utilities that are used by C++ code (i.e. .cpp files).

#pragma once

#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/cuda_utils.h"

namespace mlx::core {

// Block size suggested by the CUDA occupancy calculator for `kernel`. The
// driver-API query is used in the first branch (presumably when T is a
// CUfunction) and the runtime-API query otherwise.
template inline uint32_t max_occupancy_block_dim(T kernel) {
  int _, block_dim;
  if constexpr (std::is_same_v) {
    CHECK_CUDA_ERROR(
        cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
  } else {
    CHECK_CUDA_ERROR(
        cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
  }
  return block_dim;
}

// Device pointer to arr's data: the buffer's data pointer advanced by
// arr.offset() (presumably a byte offset — confirm).
template inline T* gpu_ptr(array& arr) {
  return reinterpret_cast(
      static_cast(
          static_cast(arr.buffer().ptr())->data) +
      arr.offset());
}

// For const array, keep constness in pointer unless it is untyped.
template
inline std::conditional_t, void*, const T*> gpu_ptr(
    const array& arr) {
  return gpu_ptr(const_cast(arr));
}

struct Dtype;

// Convert Dtype to CUDA C++ types.
const char* dtype_to_cuda_type(const Dtype& dtype);

// Allocate an empty array and add it as temporary. Returns nullptr when
// workspace_size is 0.
void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size);

} // namespace mlx::core

================================================
FILE: mlx/backend/cuda/vector_types.cuh
================================================
// Copyright © 2025 Apple Inc.
#pragma once #include #include namespace mlx::core::cu { template struct Vector2; template <> struct Vector2 { using type = double2; }; template <> struct Vector2 { using type = float2; }; template <> struct Vector2<__half> { using type = __half2; }; template <> struct Vector2<__nv_bfloat16> { using type = __nv_bfloat162; }; template using Vector2_t = typename Vector2::type; template struct Vector4 { T x, y, z, w; }; template using Vector4_t = Vector4; using bf16x4 = Vector4_t<__nv_bfloat16>; using fp16x4 = Vector4_t<__half>; using fp32x4 = Vector4_t; } // namespace mlx::core::cu ================================================ FILE: mlx/backend/cuda/worker.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/worker.h" #include "mlx/backend/cuda/device.h" namespace mlx::core::cu { Worker::Worker(Device& d) : signal_stream_(d), signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync), worker_(&Worker::thread_fn, this) {} Worker::~Worker() { { std::lock_guard lock(mtx_); stop_ = true; } cond_.notify_one(); worker_.join(); } void Worker::add_task(std::function task) { pending_tasks_.push_back(std::move(task)); } void Worker::signal(void* data) { auto w = static_cast(data); { std::lock_guard lock(w->mtx_); w->signaled_batch_++; } w->cond_.notify_one(); } void Worker::commit(cudaStream_t stream) { // Move pending tasks into tasks if (pending_tasks_.empty()) { return; } { std::lock_guard lock(mtx_); // Move pending tasks into ready tasks worker_tasks_[++committed_batch_] = std::move(pending_tasks_); } signal_event_.record(stream); signal_event_.wait(signal_stream_); CHECK_CUDA_ERROR(cudaLaunchHostFunc(signal_stream_, signal, this)); } void Worker::thread_fn() { while (!stop_) { uint64_t current_batch = 0; Tasks tasks; { std::unique_lock lk(mtx_); cond_.wait(lk, [this, ¤t_batch] { return this->signaled_batch_ > current_batch || this->stop_; }); current_batch = signaled_batch_; auto end = 
worker_tasks_.upper_bound(current_batch);
      for (auto it = worker_tasks_.begin(); it != end; ++it) {
        if (tasks.empty()) {
          tasks = std::move(it->second);
        } else {
          std::move(
              it->second.begin(), it->second.end(), std::back_inserter(tasks));
        }
      }
      worker_tasks_.erase(worker_tasks_.begin(), end);
    }
    // Make sure tasks are cleared before the next wait
    for (int i = 0; i < tasks.size(); ++i) {
      auto task = std::move(tasks[i]);
      task();
    }
  }
}

} // namespace mlx::core::cu

================================================
FILE: mlx/backend/cuda/worker.h
================================================
// Copyright © 2025 Apple Inc.

// NOTE(review): this dump lost all text inside angle brackets during
// extraction (system #include names, template arguments such as the element
// types of std::function / std::vector / std::map below). Code is reproduced
// as-is; restore those arguments from the real source before compiling.

#pragma once

#include "mlx/backend/cuda/event.h"

#include
#include
#include
#include
#include

namespace mlx::core::cu {

// Run tasks in worker thread, synchronized with cuda stream.
class Worker {
 public:
  explicit Worker(Device& d);
  ~Worker();

  // Non-copyable: the worker owns a thread and synchronization primitives.
  Worker(const Worker&) = delete;
  Worker& operator=(const Worker&) = delete;

  // Add a pending |task| that will run when consumed or committed.
  void add_task(std::function task);

  // Inform worker thread to run current batches after kernels in |stream|
  // finish running.
  void commit(cudaStream_t stream);

 private:
  static void signal(void*);

  void thread_fn();

  std::mutex mtx_;
  std::condition_variable cond_;

  uint64_t committed_batch_{0};
  uint64_t signaled_batch_{0};

  // Cuda stream and event for signaling kernel completion.
  CudaStream signal_stream_;
  CudaEvent signal_event_;

  bool stop_{false};

  // Tasks are put in |pending_tasks_| first, and then moved to
  // |worker_tasks_| when commit() is called. thread_fn() drains every entry
  // of |worker_tasks_| up to and including the current signaled batch
  // (it uses upper_bound(current_batch)), so the map is presumably keyed by
  // batch number.
  using Tasks = std::vector>;
  Tasks pending_tasks_;
  std::map worker_tasks_;
  std::thread worker_;
};

} // namespace mlx::core::cu

================================================
FILE: mlx/backend/gpu/CMakeLists.txt
================================================
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)

================================================
FILE: mlx/backend/gpu/copy.cpp
================================================
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"

#include
#include

namespace mlx::core {

// Convenience overload: runs the copy on the stream of |out|'s primitive.
void copy_gpu(const array& in, array& out, CopyType ctype) {
  copy_gpu(in, out, ctype, out.primitive().stream());
}

// In-place copy using the full shape and strides of |in| and |out| with zero
// offsets on both sides. |in| and |out| must have the same shape.
void copy_gpu_inplace(
    const array& in,
    array& out,
    CopyType ctype,
    const Stream& s) {
  assert(in.shape() == out.shape());
  return copy_gpu_inplace(
      in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}

// Same as above, but reads |in| through caller-supplied strides and an input
// offset; the output is written with its own strides at offset 0.
void copy_gpu_inplace(
    const array& in,
    array& out,
    const Strides& i_strides,
    int64_t i_offset,
    CopyType ctype,
    const Stream& s) {
  assert(in.shape() == out.shape());
  return copy_gpu_inplace(
      in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}

// Creates a new array with |arr|'s shape and dtype (no primitive, no inputs)
// and fills it with a General copy of |arr| on stream |s|.
array contiguous_copy_gpu(const array& arr, const Stream& s) {
  array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
  copy_gpu(arr, arr_copy, CopyType::General, s);
  return arr_copy;
}

// Negative axes are wrapped once by ndim, then clamped to [0, ndim - 1]
// before delegating to reshape_in_eval with Flatten's output shape.
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
  int ndim = x.ndim();
  if (start_axis < 0) {
    start_axis += ndim;
  }
  if (end_axis < 0) {
    end_axis += ndim;
  }
  start_axis = std::max(0, start_axis);
  end_axis = std::min(ndim - 1, end_axis);

  return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);
}

array reshape_in_eval(const array& x, Shape shape, Stream s) {
  array out(std::move(shape), x.dtype(), nullptr, {});
  reshape_gpu(x, out, s);
  return out;
}

// Permutes |x| without copying data: the result shares |x|'s buffer with
// permuted shape/strides, and its contiguity flags are recomputed for the
// new layout via check_contiguity.
array transpose_in_eval(const array& x, const std::vector& axes) {
  Shape shape(axes.size());
  Strides strides(axes.size());
  for (int i = 0; i < axes.size(); ++i) {
    shape[i] = x.shape(axes[i]);
    strides[i] = x.strides(axes[i]);
  }
  auto [data_size, row_contiguous, col_contiguous] =
      check_contiguity(shape, strides);
  // The permuted view is only "contiguous" if it still covers exactly the
  // original data range.
  bool contiguous = data_size == x.data_size();

  array out(std::move(shape), x.dtype(), nullptr, {});
  out.copy_shared_buffer(
      x,
      std::move(strides),
      {contiguous, row_contiguous, col_contiguous},
      x.data_size());
  return out;
}

// Builds the identity permutation, swaps |axis1| and |axis2| (negative axes
// are wrapped once, not range-checked), and delegates to transpose_in_eval.
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
  int ndim = x.ndim();
  if (axis1 < 0) {
    axis1 += ndim;
  }
  if (axis2 < 0) {
    axis2 += ndim;
  }

  std::vector axes(ndim);
  std::iota(axes.begin(), axes.end(), 0);
  std::swap(axes[axis1], axes[axis2]);
  return transpose_in_eval(x, axes);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/gpu/copy.h
================================================
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/backend/common/copy.h"
#include "mlx/stream.h"

#include
#include

namespace mlx::core {

// Generic copy inplace. The optional dynamic offsets are arrays holding
// offsets that are only known on the device (see DynamicSlice /
// DynamicSliceUpdate in gpu/primitives.cpp).
void copy_gpu_inplace(
    const array& in,
    array& out,
    const Shape& data_shape,
    const Strides& i_strides,
    const Strides& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    CopyType ctype,
    const Stream& s,
    std::optional dynamic_i_offset = std::nullopt,
    std::optional dynamic_o_offset = std::nullopt);

void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype);

void copy_gpu_inplace(
    const array& in,
    array& out,
    CopyType ctype,
    const Stream& s);

void copy_gpu_inplace(
    const array& in,
    array& out,
    const Strides& i_strides,
    int64_t i_offset,
    CopyType ctype,
    const Stream& s);

// Fill the output with the scalar val
void fill_gpu(const array& val, array& out, const Stream& s);

// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_gpu(const array& arr, const Stream& s);

// Copy data from |in| and reshape it to |out|'s shape.
void reshape_gpu(const array& in, array& out, Stream s);

// Like the normal ops but safe to call in eval_gpu.
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
array reshape_in_eval(const array& x, Shape shape, Stream s);
array transpose_in_eval(const array& x, const std::vector& axes);
array swapaxes_in_eval(const array& x, int axis1, int axis2);

} // namespace mlx::core

================================================
FILE: mlx/backend/gpu/device_info.h
================================================
// Copyright © 2026 Apple Inc.

#pragma once

#include
#include
#include

#include "mlx/api.h"

namespace mlx::core::gpu {

MLX_API bool is_available();

/**
 * Get the number of available GPU devices.
 */
MLX_API int device_count();

/**
 * Get information about a GPU device.
 *
 * Returns a map of device properties. Keys vary by backend:
 * - device_name (string): Device name
 * - architecture (string): Architecture identifier
 * - total_memory/memory_size (size_t): Total device memory
 * - free_memory (size_t): Available memory (CUDA only)
 * - uuid (string): Device UUID (CUDA only)
 * - pci_bus_id (string): PCI bus ID (CUDA only)
 * - compute_capability_major/minor (size_t): Compute capability (CUDA only)
 */
MLX_API const std::unordered_map>&
device_info(int device_index = 0);

} // namespace mlx::core::gpu

================================================
FILE: mlx/backend/gpu/eval.h
================================================
// Copyright © 2023-2024 Apple Inc.
#pragma once

#include
#include

#include "mlx/array.h"
#include "mlx/stream.h"

namespace mlx::core::gpu {

void new_stream(Stream stream);
void eval(array& arr);
void finalize(Stream s);
void synchronize(Stream s);

} // namespace mlx::core::gpu

================================================
FILE: mlx/backend/gpu/primitives.cpp
================================================
// Copyright © 2025 Apple Inc.

// NOTE(review): text inside angle brackets (system #include names, template
// arguments such as std::vector's element type) was lost when this dump was
// extracted; code is reproduced as-is.

#include "mlx/primitives.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"

#if defined(MLX_USE_CUDA)
#include
#endif

#include

// With CUDA, MLX_PROFILER_RANGE opens an NVTX scoped range named |message|
// that lasts until the end of the enclosing scope; otherwise it is a no-op.
#if defined(MLX_USE_CUDA)
#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
#else
#define MLX_PROFILER_RANGE(message)
#endif

namespace mlx::core {

void AsStrided::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("AsStrided::eval_gpu");
  eval(inputs, out);
}

void AsType::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("AsType::eval_gpu");
  // Contiguous inputs are converted as a flat vector; others need a strided
  // (General) copy.
  CopyType ctype =
      inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
  copy_gpu(inputs[0], out, ctype);
}

void Broadcast::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Broadcast::eval_gpu");
  eval(inputs, out);
}

void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
  eval(inputs, out);
}

void Concatenate::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Concatenate::eval_gpu");
  concatenate_gpu(inputs, out, axis_, stream());
}

void Contiguous::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Contiguous::eval_gpu");
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  // Share the input buffer instead of copying when the input already has the
  // requested layout and its buffer is at most |extra_bytes| larger than
  // what |out| needs (avoids pinning a much larger buffer).
  constexpr size_t extra_bytes = 16384;
  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
      (in.flags().row_contiguous ||
       (allow_col_major_ && in.flags().col_contiguous))) {
    out.copy_shared_buffer(in);
  } else {
    copy_gpu(in, out, CopyType::General);
  }
}

void Copy::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Copy::eval_gpu");
  eval(inputs, out);
}

void CustomTransforms::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
  eval(inputs, outputs);
}

void Depends::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  MLX_PROFILER_RANGE("Depends::eval_gpu");
  eval(inputs, outputs);
}

void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
  if (out.size() == 0) {
    out.set_data(allocator::malloc(0));
    return;
  }

  auto& in = inputs[0];
  auto& start = inputs[1];
  out.set_data(allocator::malloc(out.nbytes()));
  auto s = stream();
  // The start indices live in the |start| array, so the input offset is
  // computed as an array and passed as a dynamic offset rather than a
  // host-side constant.
  auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
  copy_gpu_inplace(
      /* const array& src = */ in,
      /* array& dst = */ out,
      /* const Shape& data_shape = */ out.shape(),
      /* const Strides& i_strides = */ in.strides(),
      /* const Strides& o_strides = */ out.strides(),
      /* int64_t i_offset = */ 0,
      /* int64_t o_offset = */ 0,
      /* CopyType ctype = */ CopyType::GeneralGeneral,
      /* const Stream& s = */ s,
      /* std::optional dynamic_i_offset = */ std::move(in_offset),
      /* std::optional dynamic_o_offset = */ std::nullopt);
}

void DynamicSliceUpdate::eval_gpu(
    const std::vector& inputs,
    array& out) {
  MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
  if (out.size() == 0) {
    out.set_data(allocator::malloc(0));
    return;
  }

  auto& in = inputs[0];
  auto& upd = inputs[1];
  auto& start_indices = inputs[2];

  // Nothing to write: the output is just the input.
  if (upd.size() == 0) {
    out.copy_shared_buffer(in);
    return;
  }

  // Copy or donate input to output
  auto s = stream();
  auto ctype = in.flags().contiguous && in.size() == in.data_size()
      ? CopyType::Vector
      : CopyType::General;
  copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);

  // Then write |upd| into |out| at the device-side offset derived from
  // |start_indices|.
  auto out_offset =
      compute_dynamic_offset(start_indices, out.strides(), axes_, s);
  copy_gpu_inplace(
      /* const array& src = */ upd,
      /* array& dst = */ out,
      /* const Shape& data_shape = */ upd.shape(),
      /* const Strides& i_strides = */ upd.strides(),
      /* const Strides& o_strides = */ out.strides(),
      /* int64_t i_offset = */ 0,
      /* int64_t o_offset = */ 0,
      /* CopyType ctype = */ CopyType::GeneralGeneral,
      /* const Stream& s = */ s,
      /* std::optional dynamic_i_offset = */ std::nullopt,
      /* std::optional dynamic_o_offset = */ std::move(out_offset));
}

void ExpandDims::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
  eval(inputs, out);
}

void Full::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Full::eval_gpu");
  auto in = inputs[0];
  // Pick the cheapest copy the input layout allows: broadcast a single
  // element, flat copy for contiguous data, strided copy otherwise.
  CopyType ctype;
  if (in.data_size() == 1) {
    ctype = CopyType::Scalar;
  } else if (in.flags().contiguous) {
    ctype = CopyType::Vector;
  } else {
    ctype = CopyType::General;
  }
  copy_gpu(in, out, ctype);
}

void Flatten::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Flatten::eval_gpu");
  reshape_gpu(inputs[0], out, stream());
}

void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
  eval(inputs, out);
}

void Pad::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Pad::eval_gpu");
  // Inputs must be base input array and scalar val array
  assert(inputs.size() == 2);
  auto& in = inputs[0];
  auto& val = inputs[1];

  // Padding value must be a scalar
  assert(val.size() == 1);

  // Padding value, input and output must be of the same type
  assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());

  pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}

void Reshape::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Reshape::eval_gpu");
  reshape_gpu(inputs[0], out, stream());
}

void Split::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  MLX_PROFILER_RANGE("Split::eval_gpu");
  eval(inputs, outputs);
}

void Slice::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Slice::eval_gpu");
  assert(inputs.size() == 1);
  if (out.size() == 0) {
    out.set_data(allocator::malloc(0));
    return;
  }

  auto& in = inputs[0];
  slice_gpu(in, out, start_indices_, strides_, stream());
}

void Squeeze::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Squeeze::eval_gpu");
  eval(inputs, out);
}

void StopGradient::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("StopGradient::eval_gpu");
  eval(inputs, out);
}

void Transpose::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Transpose::eval_gpu");
  eval(inputs, out);
}

void Unflatten::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("Unflatten::eval_gpu");
  reshape_gpu(inputs[0], out, stream());
}

void View::eval_gpu(const std::vector& inputs, array& out) {
  MLX_PROFILER_RANGE("View::eval_gpu");
  auto& in = inputs[0];
  auto ibytes = size_of(in.dtype());
  auto obytes = size_of(out.dtype());
  // Conditions for buffer copying (disjunction):
  // - type size is the same
  // - type size is smaller and the last axis is contiguous
  // - the entire array is row contiguous
  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
      in.flags().row_contiguous) {
    // Rescale every stride except the last from input to output element
    // units; the last axis keeps its (unit) stride.
    auto strides = in.strides();
    for (int i = 0; i < static_cast(strides.size()) - 1; ++i) {
      strides[i] *= ibytes;
      strides[i] /= obytes;
    }
    out.copy_shared_buffer(
        in, strides, in.flags(), in.data_size() * ibytes / obytes);
  } else {
    // Otherwise materialize a row-contiguous copy of |in| and reinterpret
    // its buffer with |out|'s dtype and strides.
    auto tmp = array(in.shape(), in.dtype(), nullptr, {});
    tmp.set_data(allocator::malloc(tmp.nbytes()));
    copy_gpu_inplace(in, tmp, CopyType::General, stream());

    auto flags = out.flags();
    flags.contiguous = true;
    flags.row_contiguous = true;
    // A row-contiguous array is also column-contiguous when all dimensions
    // except the largest have size 1.
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
    out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/gpu/scan.h
================================================
#pragma once

#include "mlx/array.h"
#include "mlx/primitives.h"

namespace mlx::core {

void scan_gpu_inplace(
    array in,
    array& out,
    Scan::ReduceType reduce_type,
    int axis,
    bool reverse,
    bool inclusive,
    const Stream& s);

} // namespace mlx::core

================================================
FILE: mlx/backend/gpu/slicing.cpp
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/slicing.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"

namespace mlx::core {

// The stream argument is unused; this forwards to the backend-agnostic
// slice().
void slice_gpu(
    const array& in,
    array& out,
    const Shape& start_indices,
    const Shape& strides,
    const Stream&) {
  slice(in, out, start_indices, strides);
}

// Writes |val| into every element of |out|, then copies |in| into the window
// of |out| that starts |low_pad_size[i]| elements along each axis |axes[i]|.
void pad_gpu(
    const array& in,
    const array& val,
    array& out,
    const std::vector& axes,
    const Shape& low_pad_size,
    const Stream& s) {
  // Fill output with val
  fill_gpu(val, out, s);

  // Find offset for start of input values
  size_t data_offset = 0;
  for (int i = 0; i < axes.size(); i++) {
    auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
    data_offset += out.strides()[ax] * low_pad_size[i];
  }

  // Extract slice from output where input will be pasted
  array out_slice(in.shape(), out.dtype(), nullptr, {});
  out_slice.copy_shared_buffer(
      out, out.strides(), out.flags(), out_slice.size(), data_offset);

  // Copy input values into the slice
  copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/gpu/slicing.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

void slice_gpu(
    const array& in,
    array& out,
    const Shape& start_indices,
    const Shape& strides,
    const Stream& s);

void concatenate_gpu(
    const std::vector& inputs,
    array& out,
    int axis,
    const Stream& s);

void pad_gpu(
    const array& in,
    const array& val,
    array& out,
    const std::vector& axes,
    const Shape& low_pad_size,
    const Stream& s);

array compute_dynamic_offset(
    const array& indices,
    const Strides& strides,
    const std::vector& axes,
    const Stream& s);

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/CMakeLists.txt
================================================
function(make_jit_source SRC_FILE)
  # This function takes a metal header file, runs the C preprocessor on it,
  # and makes the processed contents available as a string in a C++ function
  # mlx::core::metal::${SRC_NAME}()
  #
  # To use the function, declare it in jit/includes.h and include
  # jit/includes.h.
  #
  # Additional arguments to this function are treated as dependencies in the
  # CMake build system.
  get_filename_component(SRC_NAME ${SRC_FILE} NAME)
  add_custom_command(
    OUTPUT jit/${SRC_NAME}.cpp
    COMMAND
      bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
      ${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
      ${SRC_FILE}
    DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
  add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
  add_dependencies(mlx ${SRC_NAME})
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
endfunction(make_jit_source)

# Sources embedded regardless of MLX_METAL_JIT.
make_jit_source(
  utils
  kernels/bf16.h
  kernels/bf16_math.h
  kernels/complex.h
  kernels/defines.h
  kernels/logging.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(indexing/scatter kernels/indexing/indexing.h)
make_jit_source(indexing/masked_scatter)
make_jit_source(indexing/gather kernels/indexing/indexing.h)
make_jit_source(indexing/gather_front kernels/indexing/indexing.h)
make_jit_source(indexing/gather_axis)
make_jit_source(indexing/scatter_axis)
make_jit_source(hadamard)

# With MLX_METAL_JIT, the remaining kernel headers are also embedded as strings
# via make_jit_source and jit_kernels.cpp is built; otherwise
# nojit_kernels.cpp is used.
if(MLX_METAL_JIT)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp)
  make_jit_source(arange)
  make_jit_source(copy)
  make_jit_source(unary)
  make_jit_source(binary)
  make_jit_source(binary_two)
  make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
  make_jit_source(logsumexp)
  make_jit_source(ternary)
  make_jit_source(softmax)
  make_jit_source(scan)
  make_jit_source(sort)
  make_jit_source(
    reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
    kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
  make_jit_source(
    steel/gemm/gemm kernels/steel/utils.h kernels/steel/gemm/loader.h
    kernels/steel/gemm/mma.h kernels/steel/gemm/params.h
    kernels/steel/gemm/transforms.h)
  make_jit_source(steel/gemm/kernels/steel_gemm_fused)
  make_jit_source(steel/gemm/kernels/steel_gemm_masked
                  kernels/steel/defines.h)
  make_jit_source(steel/gemm/kernels/steel_gemm_gather)
  make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
  make_jit_source(steel/gemm/kernels/steel_gemm_segmented)
  make_jit_source(
    steel/conv/conv
    kernels/steel/utils.h
    kernels/steel/defines.h
    kernels/steel/gemm/mma.h
    kernels/steel/gemm/transforms.h
    kernels/steel/conv/params.h
    kernels/steel/conv/loader.h
    kernels/steel/conv/loaders/loader_channel_l.h
    kernels/steel/conv/loaders/loader_channel_n.h)
  make_jit_source(steel/conv/kernels/steel_conv)
  make_jit_source(steel/conv/kernels/steel_conv_3d)
  make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
                  kernels/steel/conv/loaders/loader_general.h)
  make_jit_source(quantized_utils)
  make_jit_source(quantized kernels/quantized_utils.h)
  make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
                  kernels/fp4.h)
  make_jit_source(gemv_masked)
  make_jit_source(steel/attn/kernels/steel_attention)
  make_jit_source(
    steel/gemm/gemm_nax kernels/steel/utils.h kernels/steel/gemm/nax.h
    kernels/steel/gemm/params.h kernels/steel/gemm/transforms.h)
  make_jit_source(steel/gemm/kernels/steel_gemm_fused_nax)
  make_jit_source(steel/gemm/kernels/steel_gemm_gather_nax)
  make_jit_source(steel/gemm/kernels/steel_gemm_splitk_nax)
  make_jit_source(quantized_nax kernels/quantized_utils.h)
  make_jit_source(fp_quantized_nax kernels/quantized_utils.h kernels/fp8.h
                  kernels/fp4.h)
  make_jit_source(steel/attn/kernels/steel_attention_nax)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
endif()

target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)

# Default to the kernels/ directory of the build tree; the resulting metallib
# path is compiled into the library as METAL_PATH.
if(NOT MLX_METAL_PATH)
  set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)

target_compile_definitions(mlx
                           PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")

================================================
FILE: mlx/backend/metal/allocator.cpp
================================================
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/gpu/device_info.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"

// NOTE(review): text inside angle brackets (system #include names and
// template arguments of std::get / static_cast / std::unique_lock) was lost
// when this dump was extracted; code is reproduced as-is.
#include
#include
#include

namespace mlx::core {

// All buffers use shared storage with Metal hazard tracking disabled.
constexpr size_t resource_options =
    MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeUntracked;

namespace allocator {

Allocator& allocator() {
  return metal::allocator();
}

void* Buffer::raw_ptr() {
  if (!ptr_) {
    return nullptr;
  }
  return static_cast(ptr_)->contents();
}

} // namespace allocator

namespace metal {

MetalAllocator::MetalAllocator()
    : device_(device(mlx::core::Device::gpu).mtl_device()),
      buffer_cache_(
          vm_page_size,
          [](MTL::Buffer* buf) { return buf->length(); },
          [this](MTL::Buffer* buf) {
            if (!buf->heap()) {
              residency_set_.erase(buf);
            }
            buf->release();
          }),
      residency_set_(device_) {
  auto pool = metal::new_scoped_memory_pool();
  const auto& info = gpu::device_info(0);
  auto memsize = std::get(info.at("memory_size"));
  auto max_rec_size =
      std::get(info.at("max_recommended_working_set_size"));
  resource_limit_ = std::get(info.at("resource_limit"));
  // Hard limit: min(1.5x recommended working set, 95% of device memory).
  // Cache reclamation (GC) starts at min(95% of recommended working set,
  // hard limit). The cache itself may grow up to the hard limit by default.
  block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
  gc_limit_ = std::min(static_cast(0.95 * max_rec_size), block_limit_);
  max_pool_size_ = block_limit_;
  device(mlx::core::Device::gpu)
      .set_residency_set(residency_set_.mtl_residency_set());
  bool is_vm = std::get(info.at("device_name")) ==
      "Apple Paravirtual device";
  // No small-allocation heap on a paravirtual device: heap_ keeps its
  // nullptr default and malloc() allocates directly from the device.
  if (is_vm) {
    return;
  }
  auto heap_desc = MTL::HeapDescriptor::alloc()->init();
  heap_desc->setResourceOptions(resource_options);
  heap_desc->setSize(heap_size_);
  heap_ = device_->newHeap(heap_desc);
  heap_desc->release();
  residency_set_.insert(heap_);
}

MetalAllocator::~MetalAllocator() {
  auto pool = metal::new_scoped_memory_pool();
  if (heap_) {
    heap_->release();
  }
  buffer_cache_.clear();
}

// Swaps in the new cache limit and returns the previous one.
size_t MetalAllocator::set_cache_limit(size_t limit) {
  std::unique_lock lk(mutex_);
  std::swap(limit, max_pool_size_);
  return limit;
};

// Swaps in the new memory limit, lowers the GC threshold so it never exceeds
// it, and returns the previous limit.
size_t MetalAllocator::set_memory_limit(size_t limit) {
  std::unique_lock lk(mutex_);
  std::swap(limit, block_limit_);
  gc_limit_ = std::min(
      block_limit_,
      static_cast(0.95 * device_->recommendedMaxWorkingSetSize()));
  return limit;
};

// NOTE(review): reads block_limit_ without taking mutex_, unlike the setter.
size_t MetalAllocator::get_memory_limit() {
  return block_limit_;
}

// Swaps in the new wired limit, resizes the residency set to it, and returns
// the previous limit.
size_t MetalAllocator::set_wired_limit(size_t limit) {
  std::unique_lock lk(mutex_);
  std::swap(limit, wired_limit_);
  residency_set_.resize(wired_limit_);
  return limit;
};

Buffer MetalAllocator::malloc(size_t size) {
  // Metal doesn't like empty buffers
  if (size == 0) {
    return Buffer{nullptr};
  }

  // More helpful message if maximum buffer length is exceeded
  if (size > device_->maxBufferLength()) {
    std::ostringstream msg;
    msg << "[metal::malloc] Attempting to allocate " << size
        << " bytes which is greater than"
        << " the maximum allowed buffer size of " << device_->maxBufferLength()
        << " bytes.";
    throw std::runtime_error(msg.str());
  }

  // Align up memory (sizes above one page are rounded up to a page multiple)
  if (size > vm_page_size) {
    size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
  }

  // Try the cache
  std::unique_lock lk(mutex_);
  MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
  if (!buf) {
    size_t mem_required = get_active_memory() + get_cache_memory() + size;

    auto pool = metal::new_scoped_memory_pool();

    // If we have a lot of memory pressure try to reclaim memory from the cache
    // NOTE(review): when only the resource limit triggers this branch,
    // mem_required may be below gc_limit_ and the size_t subtraction wraps,
    // which presumably asks the cache to release everything it holds.
    if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
      num_resources_ -=
          buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
    }

    // Allocate new buffer if needed
    if (num_resources_ >= resource_limit_) {
      std::ostringstream msg;
      msg << "[metal::malloc] Resource limit (" << resource_limit_
          << ") exceeded.";
      throw std::runtime_error(msg.str());
    }

    // Allocate outside the lock; small requests try the shared heap first.
    lk.unlock();
    if (size < small_size_ && heap_) {
      buf = heap_->newBuffer(size, resource_options);
    }
    if (!buf) {
      buf = device_->newBuffer(size, resource_options);
    }
    if (!buf) {
      std::ostringstream msg;
      msg << "[malloc] Unable to allocate " << size << " bytes.";
      throw std::runtime_error(msg.str());
    }
    lk.lock();

    num_resources_++;
    // Heap-backed buffers are covered by the heap's own residency entry.
    if (!buf->heap()) {
      residency_set_.insert(buf);
    }
  }

  active_memory_ += buf->length();
  peak_memory_ = std::max(peak_memory_, active_memory_);

  // Maintain the cache below the requested limit
  if (get_cache_memory() > max_pool_size_) {
    auto pool = metal::new_scoped_memory_pool();
    num_resources_ -= buffer_cache_.release_cached_buffers(
        get_cache_memory() - max_pool_size_);
  }

  return Buffer{static_cast(buf)};
}

void MetalAllocator::clear_cache() {
  std::unique_lock lk(mutex_);
  auto pool = metal::new_scoped_memory_pool();
  num_resources_ -= buffer_cache_.clear();
}

// Returns the buffer to the cache while the cache is below max_pool_size_;
// otherwise releases it immediately (outside the lock).
void MetalAllocator::free(Buffer buffer) {
  auto buf = static_cast(buffer.ptr());
  if (buf == nullptr) {
    return;
  }
  std::unique_lock lk(mutex_);
  active_memory_ -= buf->length();
  if (get_cache_memory() < max_pool_size_) {
    buffer_cache_.recycle_to_cache(buf);
  } else {
    num_resources_--;
    if (!buf->heap()) {
      residency_set_.erase(buf);
    }
    lk.unlock();
    auto pool = metal::new_scoped_memory_pool();
    buf->release();
  }
}

size_t MetalAllocator::size(Buffer buffer) const {
  return static_cast(buffer.ptr())->length();
}

// Wraps caller-provided memory in a Metal buffer (no deallocator is passed,
// so the caller presumably keeps ownership of |ptr| — confirm). Returns a
// null Buffer if Metal refuses the pointer.
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
  auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
  if (!buf) {
    return Buffer{nullptr};
  }
  std::unique_lock lk(mutex_);
  residency_set_.insert(buf);
  active_memory_ += buf->length();
  peak_memory_ = std::max(peak_memory_, active_memory_);
  num_resources_++;
  return Buffer{static_cast(buf)};
}

// Counterpart of make_buffer(): releases the wrapper without going through
// the cache.
void MetalAllocator::release(Buffer buffer) {
  auto buf = static_cast(buffer.ptr());
  if (buf == nullptr) {
    return;
  }
  std::unique_lock lk(mutex_);
  active_memory_ -= buf->length();
  num_resources_--;
  residency_set_.erase(buf);
  lk.unlock();
  auto pool = metal::new_scoped_memory_pool();
  buf->release();
}

MetalAllocator& allocator() {
  // By creating the |allocator_| on heap, the destructor of MetalAllocator
  // will not be called on exit and buffers in the cache will be leaked. This
  // can save some time at program exit.
  static MetalAllocator* allocator_ = new MetalAllocator;
  return *allocator_;
}

} // namespace metal

size_t set_cache_limit(size_t limit) {
  return metal::allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit) {
  return metal::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
  return metal::allocator().get_memory_limit();
}
// Rejects wired limits above the device's recommended maximum working set.
size_t set_wired_limit(size_t limit) {
  if (limit > std::get(
                  gpu::device_info(0).at("max_recommended_working_set_size"))) {
    throw std::invalid_argument(
        "[metal::set_wired_limit] Setting a wired limit larger than "
        "the maximum working set size is not allowed.");
  }
  return metal::allocator().set_wired_limit(limit);
}
size_t get_active_memory() {
  return metal::allocator().get_active_memory();
}
size_t get_peak_memory() {
  return metal::allocator().get_peak_memory();
}
void reset_peak_memory() {
  metal::allocator().reset_peak_memory();
}
size_t get_cache_memory() {
  return metal::allocator().get_cache_memory();
}
void clear_cache() {
  return metal::allocator().clear_cache();
}

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/allocator.h
================================================
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include
#include
#include

#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"

namespace mlx::core::metal {

using allocator::Buffer;

class MetalAllocator : public allocator::Allocator {
  /** Allocator for Metal GPUs.
   */
 public:
  virtual Buffer malloc(size_t size) override;
  virtual void free(Buffer buffer) override;
  virtual size_t size(Buffer buffer) const override;
  virtual Buffer make_buffer(void* ptr, size_t size) override;
  virtual void release(Buffer buffer) override;

  size_t get_active_memory() {
    return active_memory_;
  };
  size_t get_peak_memory() {
    return peak_memory_;
  };
  void reset_peak_memory() {
    std::unique_lock lk(mutex_);
    peak_memory_ = 0;
  };
  size_t get_cache_memory() {
    return buffer_cache_.cache_size();
  };
  size_t set_cache_limit(size_t limit);
  size_t set_memory_limit(size_t limit);
  size_t get_memory_limit();
  size_t set_wired_limit(size_t limit);
  void clear_cache();

 private:
  MTL::Device* device_;

  // The size of allocations which go on the heap until it is full. This size
  // is chosen because it is the actual minimum size of a buffer allocated from
  // the heap, a heap can have at most heap.size() / 256 buffers.
  static constexpr int small_size_ = 256;
  static constexpr int heap_size_ = 1 << 20;
  // Stays null when no heap is created (the constructor returns early on a
  // paravirtual device); malloc() and the destructor test it before use, so
  // it must never be left indeterminate.
  MTL::Heap* heap_{nullptr};
  MetalAllocator();
  ~MetalAllocator();
  friend MetalAllocator& allocator();

  // Caching allocator
  BufferCache buffer_cache_;

  ResidencySet residency_set_;

  // Allocation stats
  size_t block_limit_;
  size_t gc_limit_;
  size_t active_memory_{0};
  size_t peak_memory_{0};
  size_t max_pool_size_;
  size_t wired_limit_{0};
  size_t num_resources_{0};
  size_t resource_limit_{0};

  std::mutex mutex_;
};

MetalAllocator& allocator();

} // namespace mlx::core::metal

================================================
FILE: mlx/backend/metal/binary.cpp
================================================
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

// NOTE(review): template arguments (e.g. std::vector's element type) were lost
// when this dump was extracted; code is reproduced as-is.

// Defines eval_gpu for a single-output binary primitive; the kernel op name
// is the primitive's name().
#define BINARY_GPU(func)                                      \
  void func::eval_gpu(const std::vector& inputs, array& out) { \
    binary_op_gpu(inputs, out, name());                       \
  }

// Same as BINARY_GPU for primitives producing two outputs (e.g. DivMod).
#define BINARY_GPU_MULTI(func)                                \
  void func::eval_gpu(                                        \
      const std::vector& inputs, std::vector& outputs) {      \
    binary_op_gpu(inputs, outputs, name());                   \
  }

namespace mlx::core {

// Builds the Metal kernel name:
//   "ss" / "sv" / "vs" / "vv" for scalar/vector operand combinations, or
//   "g" for general strided kernels followed by ndim (when <= 3) or
//   "n<work_per_thread>", plus "large" for large general kernels;
//   non-general, non-scalar-scalar kernels get "2" when large or "n" when
//   work_per_thread > 1;
// then "_", the op name and the dtype name of |a|.
std::string get_kernel_name(
    BinaryOpType bopt,
    const char* op,
    const array& a,
    bool large,
    int ndim,
    int work_per_thread) {
  std::string kname;
  switch (bopt) {
    case BinaryOpType::ScalarScalar:
      kname = "ss";
      break;
    case BinaryOpType::ScalarVector:
      kname = "sv";
      break;
    case BinaryOpType::VectorScalar:
      kname = "vs";
      break;
    case BinaryOpType::VectorVector:
      kname = "vv";
      break;
    case BinaryOpType::General:
      kname = "g";
      if (ndim <= 3) {
        kname += std::to_string(ndim);
      } else {
        concatenate(kname, "n", std::to_string(work_per_thread));
      }
      if (large) {
        kname += "large";
      }
      break;
  }
  if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
    if (large) {
      kname += "2";
    } else if (work_per_thread > 1) {
      kname += "n";
    }
  }
  concatenate(kname, "_", op, type_to_name(a));
  return kname;
}

// Encodes a binary kernel over inputs[0] and inputs[1] into one or two
// already-allocated outputs (two for ops such as DivMod). Does nothing when
// the first output is empty.
void binary_op_gpu_inplace(
    const std::vector& inputs,
    std::vector& outputs,
    const char* op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);

  auto& out = outputs[0];
  if (out.size() == 0) {
    return;
  }

  // Try to collapse contiguous dims (only the General path needs
  // shapes/strides; other paths get empty placeholders)
  auto maybe_collapse = [bopt, &a, &b, &out]() {
    if (bopt == BinaryOpType::General) {
      auto [shape, strides] = collapse_contiguous_dims(a, b, out);
      return std::make_tuple(shape, strides[0], strides[1], strides[2]);
    } else {
      decltype(a.strides()) e{};
      return std::make_tuple(decltype(a.shape()){}, e, e, e);
    }
  };
  auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();

  // Select the "large" kernel variants when sizes exceed 32-bit limits.
  bool large;
  auto ndim = shape.size();
  int work_per_thread;
  if (bopt == BinaryOpType::General) {
    large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
        out.size() > INT32_MAX;
    work_per_thread = large ? 4 : 2;
  } else {
    large = out.data_size() > UINT32_MAX;
    work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
  }
  std::string kernel_name =
      get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
  auto& d = metal::device(s.device);

  auto kernel = outputs.size() == 2
      ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
      : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Buffer order: a, b, output(s), then per-path shape/stride/size arguments.
  int arg_idx = 0;
  compute_encoder.set_input_array(a, arg_idx++);
  compute_encoder.set_input_array(b, arg_idx++);
  compute_encoder.set_output_array(outputs[0], arg_idx++);
  if (outputs.size() == 2) {
    compute_encoder.set_output_array(outputs[1], arg_idx++);
  }

  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (bopt == BinaryOpType::General) {
    // Launch up to 3D grid of threads
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = out.size() / (dim0 * dim1);

    if (ndim > 3) {
      compute_encoder.set_vector_bytes(shape, arg_idx++);
      compute_encoder.set_vector_bytes(strides_a, arg_idx++);
      compute_encoder.set_vector_bytes(strides_b, arg_idx++);
      compute_encoder.set_bytes(ndim, arg_idx++);
      // Each thread handles |work_per_thread| elements of the last axis.
      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder.set_vector_bytes(strides_a, arg_idx++);
      compute_encoder.set_vector_bytes(strides_b, arg_idx++);
    }

    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  } else {
    // Launch a 1D or 2D grid of threads
    size_t nthreads = ceildiv(out.data_size(), work_per_thread);
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }

    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    MTL::Size grid_dims;
    // NOTE(review): both branches pass out.data_size() identically here; the
    // explicit set_bytes template argument (integer width) appears to have
    // been lost in extraction — confirm against the real source.
    if (large) {
      compute_encoder.set_bytes(out.data_size(), arg_idx++);
      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
    } else {
      compute_encoder.set_bytes(out.data_size(), arg_idx++);
      grid_dims = MTL::Size(nthreads, 1, 1);
    }
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  }
}

// Allocates storage for both outputs, then encodes the kernel. Assumes
// exactly two outputs (outputs[1] is accessed unconditionally).
void binary_op_gpu(
    const std::vector& inputs,
    std::vector& outputs,
    const char* op,
    const Stream& s) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, outputs[0], bopt);
  set_binary_op_output_data(a, b, outputs[1], bopt);
  binary_op_gpu_inplace(inputs, outputs, op, s);
}

// Same as above, on the stream of the first output's primitive.
void binary_op_gpu(
    const std::vector& inputs,
    std::vector& outputs,
    const char* op) {
  auto& s = outputs[0].primitive().stream();
  binary_op_gpu(inputs, outputs, op, s);
}

// Single-output wrapper around the multi-output in-place encoder.
void binary_op_gpu_inplace(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  std::vector outputs = {out};
  binary_op_gpu_inplace(inputs, outputs, op, s);
}

// Allocates storage for |out|, then encodes the kernel.
void binary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
  binary_op_gpu_inplace(inputs, out, op, s);
}

// Same as above, on the stream of |out|'s primitive.
void binary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op) {
  auto& s = out.primitive().stream();
  binary_op_gpu(inputs, out, op, s);
}

BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU_MULTI(DivMod)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)

// Every bitwise op uses the same path; the specific op is selected through
// name().
void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) {
  switch (op_) {
    case BitwiseBinary::And:
      binary_op_gpu(inputs, out, name());
      break;
    case BitwiseBinary::Or:
      binary_op_gpu(inputs, out, name());
      break;
    case BitwiseBinary::Xor:
      binary_op_gpu(inputs, out, name());
      break;
    case BitwiseBinary::LeftShift:
      binary_op_gpu(inputs, out, name());
      break;
    case BitwiseBinary::RightShift:
      binary_op_gpu(inputs, out, name());
      break;
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/binary.h
================================================
// Copyright © 2024 Apple Inc.
#pragma once

#include "mlx/array.h"

namespace mlx::core {

// Set up the output(s) of a two-input elementwise op, then encode the Metal
// kernel identified by `op` on stream `s`. The vector overload indexes
// outputs[0] and outputs[1], i.e. it is for two-output ops.
void binary_op_gpu(
    const std::vector& inputs,
    std::vector& outputs,
    const char* op,
    const Stream& s);

void binary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s);

// Encode the kernel without setting up output storage; binary_op_gpu does
// that before calling these.
void binary_op_gpu_inplace(
    const std::vector& inputs,
    std::vector& outputs,
    const char* op,
    const Stream& s);

void binary_op_gpu_inplace(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s);

} // namespace mlx::core
================================================ FILE: mlx/backend/metal/compiled.cpp ================================================
// Copyright © 2023-2024 Apple Inc.
// NOTE(review): the two bare #include directives below have lost their
// targets in this copy; from usage (fmt::format, std::ostringstream) they are
// likely the fmt header and <sstream>. Template argument lists (e.g. on
// std::vector, std::function) also appear stripped. Restore before building.
#include
#include

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

// Appends the Metal source of one fused elementwise kernel named
// `kernel_name` to `os`.
//
// Generated signature, in buffer-index order:
//   - one `device const T*` per input, skipping inputs for which
//     is_constant(i) is true (those are inlined as literals instead);
//   - `in_strides` (int64_t) when the kernel is strided and at least one
//     non-constant input is not a scalar;
//   - one `device T*` per output;
//   - `output_shape` for strided kernels, or `size` for contiguous ones;
//   - `ndim` when `dynamic_dims` is set.
// The body loads inputs into `tmp_<name>` locals, evaluates `tape` in order,
// and stores every output at `index`.
//
// contiguous      - non-scalar inputs are read at the outputs' flat index.
// ndim            - static rank that specializes strided indexing
//                   (Compiled::eval_gpu passes 0 for contiguous/dynamic).
// dynamic_dims    - rank is read from the `ndim` buffer at runtime.
// use_big_index   - do index arithmetic in int64_t instead of uint.
// work_per_thread - elements processed per thread (`N_` in the kernel).
//
// Throws std::runtime_error when more than 31 buffer arguments are needed.
inline void build_kernel(
    std::string& os,
    const std::string& kernel_name,
    const std::vector& inputs,
    const std::vector& outputs,
    const std::vector& tape,
    const std::function& is_constant,
    bool contiguous,
    int ndim,
    bool dynamic_dims,
    bool use_big_index = false,
    int work_per_thread = 1) {
  NodeNamer namer;
  bool add_indices = false;
  int cnt = 0;

  // Start the kernel
  os += fmt::format(
      "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);

  // Add the input arguments
  for (size_t i = 0; i < inputs.size(); ++i) {
    // Skip constants from the input list
    if (is_constant(i)) {
      continue;
    }

    const auto& x = inputs[i];
    auto& xname = namer.get_name(x);

    // Scalars and contiguous need no strides
    if (!is_scalar(x) && !contiguous) {
      add_indices = true;
    }
    os += fmt::format(
        " device const {0}* {1} [[buffer({2})]],\n",
        get_type_string(x.dtype()),
        xname,
        cnt++);
  }

  std::string idx_type = use_big_index ? "int64_t" : "uint";
  if (add_indices) {
    os += fmt::format(
        " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
  }

  // Add the output arguments
  for (auto& x : outputs) {
    os += fmt::format(
        " device {0}* {1} [[buffer({2})]],\n",
        get_type_string(x.dtype()),
        namer.get_name(x),
        cnt++);
  }
  // Add output strides and shape to extract the indices.
  if (!contiguous) {
    os += fmt::format(
        " constant const int* output_shape [[buffer({0})]],\n", cnt++);
  } else {
    os += fmt::format(
        " constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
  }
  if (dynamic_dims) {
    os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
  }

  // The thread index in the whole grid
  os += " uint3 pos [[thread_position_in_grid]],\n";
  os += " uint3 grid [[threads_per_grid]]) {\n";

  os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
  // Flat output index of this thread's first element.
  if (contiguous && use_big_index) {
    // This is only used for contiguous kernels which don't have
    // a third grid dimension
    os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
  } else if (contiguous) {
    os += " uint index = N_ * pos.x;\n";
  } else if (work_per_thread > 1) {
    // Strided with several elements per thread: grid x spans only the last
    // dimension divided by N_ (see eval_gpu), so rebuild the flat index from
    // that dimension's real extent, `xshape`.
    os += fmt::format(
        " int xshape = output_shape[{0}];\n",
        dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
    os += fmt::format(
        " {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
        idx_type);
  } else {
    os += fmt::format(
        " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
        idx_type);
  }
  if (work_per_thread > 1 && contiguous) {
    // Stop at `size` so the last thread may process fewer than N_ elements.
    os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
  }

  // Read constant / contiguous inputs in tmps
  std::vector nc_inputs;
  for (int i = 0; i < inputs.size(); ++i) {
    auto& x = inputs[i];
    auto& xname = namer.get_name(x);

    if (is_constant(i)) {
      // Constants are printed into the source as a literal cast to the
      // input's type. NOTE(review): type_str is unused.
      auto type_str = get_type_string(x.dtype());
      std::ostringstream ss;
      print_constant(ss, x);
      os += fmt::format(
          " auto tmp_{0} = static_cast<{1}>({2});\n",
          xname,
          get_type_string(x.dtype()),
          ss.str());
    } else if (is_scalar(x)) {
      os += fmt::format(
          " {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
    } else if (contiguous) {
      os += fmt::format(
          " {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
    } else {
      // Strided, non-scalar input: its offset is computed below.
      nc_inputs.push_back(x);
    }
  }

  // Initialize the indices for non-contiguous inputs
  // in_strides holds one ndim-sized block per entry of nc_inputs, in order.
  for (int i = 0; i < nc_inputs.size(); ++i) {
    auto& xname = namer.get_name(nc_inputs[i]);
    os += fmt::format(" {0} index_{1} = ", idx_type, xname);
    if (ndim == 1) {
      int offset = i * ndim;
      os += fmt::format("elem_to_loc_1(pos.x, in_strides[{0}]);\n", offset);
    } else if (ndim == 2) {
      int offset = i * ndim;
      os += fmt::format(
          "elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
          idx_type,
          offset);
    } else if (ndim == 3) {
      int offset = i * ndim;
      os += fmt::format(
          "elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
    } else if (!dynamic_dims) {
      // Static rank > 3: start from the last two dimensions (pos.x scaled by
      // N_, and pos.y); the leading dimensions are added from pos.z below.
      int offset = (i + 1) * ndim;
      os += fmt::format(
          "N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
          idx_type,
          offset - 1,
          offset - 2);
    } else {
      os += fmt::format(
          "N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
          idx_type,
          i);
    }
  }
  if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
    // Decode pos.z into coordinates of the leading ndim - 2 dimensions
    // (innermost first) and accumulate their stride contributions.
    os += " uint zpos = pos.z;\n";
    if (dynamic_dims) {
      os += " for (int d = ndim - 3; d >= 0; --d) {\n";
    } else {
      os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
    }
    os += " uint l = zpos % output_shape[d];\n";
    for (int i = 0; i < nc_inputs.size(); ++i) {
      auto& xname = namer.get_name(nc_inputs[i]);
      os += fmt::format(" index_{0} += ", xname);
      if (dynamic_dims) {
        os +=
            fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
      } else {
        os +=
            fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
      }
    }
    os += " zpos /= output_shape[d];\n }\n";
  }

  // Open per-thread loop
  if (work_per_thread > 1 && !contiguous) {
    os +=
        " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
  }

  // Read non-contiguous inputs into tmps
  for (int i = 0; i < nc_inputs.size(); ++i) {
    auto& x = nc_inputs[i];
    auto& xname = namer.get_name(x);
    os += fmt::format(
        " {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
  }

  // Actually write the computation
  // Each tape node becomes `T tmp_<node> = <PrimitiveName>()(tmp_<in>...);`,
  // or a static_cast for cast primitives. NOTE(review): assumes every
  // non-cast tape node has at least one input (size() - 1 and back() below).
  for (auto& x : tape) {
    os += fmt::format(
        " {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
    if (is_static_cast(x.primitive())) {
      os += fmt::format(
          "static_cast<{0}>(tmp_{1});\n",
          get_type_string(x.dtype()),
          namer.get_name(x.inputs()[0]));
    } else {
      os += x.primitive().name();
      os += "()(";
      for (int i = 0; i < x.inputs().size() - 1; i++) {
        os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
      }
      os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
    }
  }

  // Write the outputs from tmps
  for (auto& x : outputs) {
    os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
  }
  // Increment indices and close per thread loop
  if (work_per_thread > 1) {
    // Advance each strided input by its innermost stride per element.
    for (int i = 0; i < nc_inputs.size(); ++i) {
      auto& x = nc_inputs[i];
      auto& xname = namer.get_name(x);
      if (!dynamic_dims) {
        os += fmt::format(
            " index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
      } else {
        os += fmt::format(
            " index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
      }
    }
    os += " index++;\n }\n";
  }

  // Finish the kernel
  os += "}\n";

  // The kernel's buffer argument slots are limited; fail loudly rather than
  // emit a kernel whose arguments cannot all be bound.
  if (cnt > 31) {
    std::ostringstream msg;
    msg << "[compile] Too many inputs/outputs fused in the Metal Compiled "
        << "primitive which exhausted the available argument buffers for "
        << "the kernel. Please file an issue with the function that results "
        << "in this error. The name of the kernel is '" << kernel_name
        << "'";
    throw std::runtime_error(msg.str());
  }
}

// Evaluates the fused graph on the GPU.
//
// The Metal library for kernel_lib_ is fetched from the device; the lambda
// passed to get_library builds its source (presumably only when it is not
// cached yet). The library contains every variant this function may pick:
// `_contiguous` (+ `_contiguous_n` when the dtype's work-per-thread is > 1,
// and `_contiguous_large`), `_strided_<1..7>` (+ `_large` for ranks 2..7),
// and `_strided_dynamic[_large]` for rank >= 8.
// Buffers are bound in the order build_kernel declares them: non-constant
// inputs, in_strides, outputs, output shape or size, ndim.
void Compiled::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  // Get the kernel if someone else built it already
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto lib = d.get_library(kernel_lib_, [&]() {
    // The work_per_thread values baked in here must match those chosen at
    // dispatch time below.
    int work_per_thread = get_work_per_thread(outputs_[0].dtype());
    std::string kernel = metal::utils();
    concatenate(
        kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
    build_kernel(
        kernel,
        kernel_lib_ + "_contiguous",
        inputs_,
        outputs_,
        tape_,
        is_constant_,
        /* contiguous = */ true,
        /* ndim = */ 0,
        /* dynamic_dims = */ false,
        /* use_big_index = */ false,
        /* work_per_thread = */ 1);
    if (work_per_thread > 1) {
      build_kernel(
          kernel,
          kernel_lib_ + "_contiguous_n",
          inputs_,
          outputs_,
          tape_,
          is_constant_,
          /* contiguous = */ true,
          /* ndim = */ 0,
          /* dynamic_dims = */ false,
          /* use_big_index = */ false,
          /* work_per_thread = */ work_per_thread);
    }
    build_kernel(
        kernel,
        kernel_lib_ + "_contiguous_large",
        inputs_,
        outputs_,
        tape_,
        is_constant_,
        /* contiguous = */ true,
        /* ndim = */ 0,
        /* dynamic_dims = */ false,
        /* use_big_index = */ true,
        /* work_per_thread = */ work_per_thread);
    for (int i = 1; i < 8; i++) {
      build_kernel(
          kernel,
          kernel_lib_ + "_strided_" + std::to_string(i),
          inputs_,
          outputs_,
          tape_,
          is_constant_,
          /* contiguous = */ false,
          /* ndim = */ i,
          /* dynamic_dims = */ false,
          /* use_big_index = */ false,
          /* work_per_thread = */ i > 3 ? 2 : 1);
      if (i > 1) {
        build_kernel(
            kernel,
            kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
            inputs_,
            outputs_,
            tape_,
            is_constant_,
            /* contiguous = */ false,
            /* ndim = */ i,
            /* dynamic_dims = */ false,
            /* use_big_index = */ true,
            /* work_per_thread = */ i > 3 ? 4 : 1);
      }
    }
    build_kernel(
        kernel,
        kernel_lib_ + "_strided_dynamic",
        inputs_,
        outputs_,
        tape_,
        is_constant_,
        /* contiguous = */ false,
        /* ndim = */ 0,
        /* dynamic_dims = */ true,
        /* use_big_index = */ false,
        /* work_per_thread = */ 2);
    build_kernel(
        kernel,
        kernel_lib_ + "_strided_dynamic_large",
        inputs_,
        outputs_,
        tape_,
        is_constant_,
        /* contiguous = */ false,
        /* ndim = */ 0,
        /* dynamic_dims = */ true,
        /* use_big_index = */ true,
        /* work_per_thread = */ 4);
    return kernel;
  });

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
  auto [contiguous, shape, strides] =
      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);

  // Whether to use large index.
  bool large = compiled_use_large_index(inputs, outputs, contiguous);

  // Get the kernel from the lib
  int ndim = shape.size();
  bool dynamic = ndim >= 8;
  auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
  int work_per_thread = 1;
  if (!contiguous) {
    if (dynamic) {
      kernel_name += "dynamic";
    } else {
      kernel_name += std::to_string(shape.size());
    }
    // Mirrors the work_per_thread each strided variant was built with.
    work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
  } else {
    // NOTE(review): the `_n` / `_large` contiguous kernels were built with
    // get_work_per_thread(dtype), while this uses the (dtype, size) overload;
    // presumably the two agree whenever the result is > 1 — confirm.
    work_per_thread =
        get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
    if (work_per_thread > 1 && !large) {
      kernel_name += "_n";
    }
  }
  // NOTE(review): no `_strided_1_large` variant is built above (large strided
  // kernels exist only for ranks 2..7 and dynamic); confirm that a rank-1
  // strided launch can never have `large` set.
  if (large) {
    kernel_name += "_large";
  }
  auto kernel = d.get_kernel(kernel_name, lib);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Put the inputs in
  int cnt = 0;
  int stride_idx = 1; // idx 0 is the output strides
  Strides in_strides;
  for (int i = 0; i < inputs.size(); i++) {
    if (is_constant_(i)) {
      continue;
    }
    auto& x = inputs[i];
    compute_encoder.set_input_array(x, cnt++);
    // Concatenate per-input strides into one table, in the same order as
    // build_kernel's nc_inputs (non-constant, non-scalar inputs).
    if (!contiguous && !is_scalar(x)) {
      in_strides.insert(
          in_strides.end(),
          strides[stride_idx].begin(),
          strides[stride_idx].end());
      stride_idx++;
    }
  }
  if (!in_strides.empty()) {
    compute_encoder.set_vector_bytes(in_strides, cnt++);
  }

  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);

  // Put the outputs in
  for (auto& x : outputs) {
    compute_encoder.set_output_array(x, cnt++);
  }

  // Put the output shape and strides in
  if (!contiguous) {
    compute_encoder.set_vector_bytes(shape, cnt++);
  } else {
    auto size = outputs[0].data_size();
    // NOTE(review): build_kernel declares `size` as int64_t for large kernels
    // and uint otherwise, yet both calls below look identical in this copy;
    // confirm the explicit template arguments on set_bytes are present.
    if (large) {
      compute_encoder.set_bytes(size, cnt++);
    } else {
      compute_encoder.set_bytes(size, cnt++);
    }
  }

  // Put the number of dims in if it is dynamic
  if (dynamic) {
    compute_encoder.set_bytes(ndim, cnt++);
  }

  // Launch the kernel
  if (contiguous) {
    // 1-D grid of ceil(size / work_per_thread) threads, or a 2-D grid for
    // large sizes (matching the kernel's `pos.x + grid.x * pos.y` index).
    size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
    MTL::Size group_dims(
        std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);

    MTL::Size grid_dims = large
        ? get_2d_grid_dims(
              outputs[0].shape(), outputs[0].strides(), work_per_thread)
        : MTL::Size(nthreads, 1, 1);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  } else {
    // 3-D grid: x = last dim divided by work_per_thread (rounded up),
    // y = second-to-last dim, z = all remaining elements.
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = outputs[0].size() / (dim0 * dim1);
    dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    // pow2 is presumably log2 of the thread-group budget handed to
    // get_block_dims: 1024 -> 10; anything in (512, 1024) falls back to 9.
    int pow2;
    if (thread_group_size == 1024) {
      pow2 = 10;
    } else if (thread_group_size > 512) {
      pow2 = 9;
    } else {
      throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
    }
    auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  }
}

} // namespace mlx::core
================================================ FILE: mlx/backend/metal/conv.cpp ================================================
// Copyright © 2023-2024 Apple Inc.
// NOTE(review): the three bare #include directives below have lost their
// targets in this copy (std::lcm, used below, needs <numeric>); template
// headers/arguments also appear stripped throughout (e.g. `template void`,
// `MLXConvParams&`, `std::min(conv_params.C, 64)`). Restore before building.
#include
#include
#include

#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

using namespace mlx::steel;

namespace mlx::core {

namespace {

// Returns `x` itself when it is already row-contiguous; otherwise a
// row-contiguous GPU copy, registered with the device as a temporary for
// stream `s` (presumably so it outlives the encoded work).
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
  if (x.flags().row_contiguous) {
    return x;
  }
  auto result = contiguous_copy_gpu(x, s);
  d.add_temporary(result, s.index);
  return result;
}

// N-d convolution by explicit unfolding + GEMM; the non-grouped path
// (dispatch_conv_3D_gpu uses it when groups == 1).
//
// in_unfolded is (M, K) with M = out.size() / O and K = wt.size() / O,
// filled by the `naive_unfold_nd_<type>_<N>` kernel. The weights are viewed
// in place as a (K, O) matrix with strides {1, K}, flagged column-contiguous,
// so no weight copy is made. steel_matmul then writes the (M, O) product into
// `out`. in_unfolded is passed in `copies` (presumably to keep it alive until
// the GPU work completes).
template void explicit_gemm_conv_ND_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array& out,
    const MLXConvParams& conv_params) {
  // Get gemm shapes
  int implicit_M = out.size() / conv_params.O;
  int implicit_K = wt.size() / conv_params.O;
  int implicit_N = conv_params.O;
  // Prepare unfolding array
  Shape unfolded_shape{implicit_M, implicit_K};
  array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});

  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

  // Prepare unfolding kernel
  std::string kname;
  kname.reserve(32);
  concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(in_unfolded, 1);

  compute_encoder.set_bytes(conv_params, 2);

  // Launch unfolding kernel
  // Grid is (C, K / C, M). Thread-group x is min(C, 64) rounded up to a
  // multiple of 32 and y fills the group to 256 threads; both are clamped to
  // the grid extent.
  size_t tgp_x = std::min(conv_params.C, 64);
  tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
  size_t tgp_y = 256 / tgp_x;

  MTL::Size grid_dims = MTL::Size(
      conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
  MTL::Size group_dims = MTL::Size(
      std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);

  compute_encoder.dispatch_threads(grid_dims, group_dims);

  // Reshape weight
  Shape wt_reshape{implicit_K, implicit_N};
  Strides wt_restride{1, implicit_K};
  array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
  auto wt_flags = wt.flags();
  wt_flags.row_contiguous = false;
  wt_flags.col_contiguous = true;
  wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());

  // Perform gemm
  std::vector copies = {in_unfolded};
  return steel_matmul(
      s,
      d,
      /*a = */ in_unfolded,
      /*b = */ wt_reshaped,
      /*c = */ out,
      /*M = */ implicit_M,
      /*N = */ implicit_N,
      /*K = */ implicit_K,
      /*batch_size_out = */ 1,
      /*a_cols = */ implicit_K,
      /*b_cols = */ implicit_K,
      /*a_transposed = */ false,
      /*b_transposed = */ true,
      /*copies = */ copies);
}

// Grouped N-d convolution by explicit unfolding + one batched GEMM whose
// batch runs over the groups.
//
// in_unfolded is (M, K * groups), filled by `naive_unfold_transpose_nd`;
// per the lda / batch-stride arguments below, group g's A occupies columns
// [g * K, (g + 1) * K). The weights are re-laid as
// (O, C_per_group, kernel_size) and materialized, so group g's B block starts
// at g * N * K. Group g's (M, N) result presumably lands in columns
// [g * N, (g + 1) * N) of `out` (ldd = N * groups, output stride N).
template void explicit_gemm_conv_group_ND_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array& out,
    const MLXConvParams& conv_params) {
  const int groups = conv_params.groups;
  const int C_per_group = conv_params.C / conv_params.groups;
  const int O_per_group = conv_params.O / conv_params.groups;
  // Get gemm shapes
  const int implicit_M = out.size() / conv_params.O;
  const int implicit_K = wt.size() / conv_params.O;
  const int implicit_N = O_per_group;

  int kernel_size = 1;
  for (int i = 0; i < N; ++i) {
    kernel_size *= conv_params.wS[i];
  }

  // Prepare unfolding array
  Shape unfolded_shape{implicit_M, implicit_K * groups};
  array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

  // Prepare unfolding kernel
  std::string kname;
  kname.reserve(32);
  concatenate(
      kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(in_unfolded, 1);

  compute_encoder.set_bytes(conv_params, 2);

  // Launch unfolding kernel
  // Same launch geometry as explicit_gemm_conv_ND_gpu: grid (C, K / C, M).
  size_t tgp_x = std::min(conv_params.C, 64);
  tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
  size_t tgp_y = 256 / tgp_x;

  MTL::Size grid_dims = MTL::Size(
      conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
  MTL::Size group_dims = MTL::Size(
      std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);

  compute_encoder.dispatch_threads(grid_dims, group_dims);

  // Transpose kernel weights so that we can slice them by contiguous chunks
  // of channel groups.
  // The view swaps the flattened spatial and channel axes of each filter;
  // it assumes wt's channel axis is innermost (stride 1) — TODO confirm.
  array wt_view(
      {wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
  wt_view.copy_shared_buffer(
      wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());

  // Materialize
  array wt_transpose = contiguous_copy_gpu(wt_view, s);

  // Perform gemm
  std::vector copies = {in_unfolded, wt_transpose};
  return steel_matmul_regular(
      /* const Stream& s = */ s,
      /* Device& d = */ d,
      /* const array& a = */ in_unfolded,
      /* const array& b = */ wt_transpose,
      /* array& c = */ out,
      /* int M = */ implicit_M,
      /* int N = */ implicit_N,
      /* int K = */ implicit_K,
      /* int batch_size_out = */ groups,
      /* int lda = */ implicit_K * groups,
      /* int ldb = */ implicit_K,
      /* int ldd = */ implicit_N * groups,
      /* bool transpose_a = */ false,
      /* bool transpose_b = */ true,
      /* std::vector& copies = */ copies,
      /* Shape batch_shape = */ {1},
      /* Strides batch_strides = */ {0},
      /* int64_t A_batch_strides = */ int64_t(implicit_K),
      /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
      /* int64_t matrix_stride_out = */ int64_t(implicit_N));
}

// 2-D convolution as an implicit GEMM (no unfolded buffer) using the steel
// conv kernels: per group, M = N * oH * oW, N = O_per_group and
// K = wH * wW * C_per_group.
//
// Tiles: bm in {32, 64}, bn in {32, 64} (bn = 8 with a 4 x 1 warp layout when
// N <= 16), bk = 16. C_per_group <= 4 selects a small-channel kernel
// specialization with an adjusted K-iteration count; otherwise `small_filter`
// marks filters no larger than 16 x 16. inp_jump_* are input-offset steps
// built from the input strides and kernel dilation (spatial terms negated
// when `flip` is set), presumably applied by the kernel while walking the
// filter. Launch: threadgroups of (32, wn, wm) threads over a grid of
// (tn, tm, groups) threadgroups (swizzle_log is 0).
void implicit_gemm_conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array& out,
    const MLXConvParams<2>& conv_params) {
  const int groups = conv_params.groups;
  const int C_per_group = conv_params.C / conv_params.groups;
  const int O_per_group = conv_params.O / conv_params.groups;

  // Deduce implicit gemm size
  const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
  const int implicit_N = O_per_group;
  const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;

  // Determine block and warp tiles
  int wm = 2, wn = 2;

  int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
  int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
  int bk = 16;

  if (implicit_N <= 16) {
    bn = 8;
    wm = 4;
    wn = 1;
  }

  int tn = (implicit_N + bn - 1) / bn;
  int tm = (implicit_M + bm - 1) / bm;
  int swizzle_log = 0;

  // Fix small channel specialization
  // C_per_group <= 2: iterate over K directly; 3..4: count channels as if
  // padded to 4 per filter tap.
  int n_channel_specialization = 0;
  int channel_k_iters = ((C_per_group + bk - 1) / bk);
  int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;

  if (C_per_group <= 2) {
    gemm_k_iters = (implicit_K + bk - 1) / bk;
    n_channel_specialization = C_per_group;
  } else if (C_per_group <= 4) {
    gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
    n_channel_specialization = C_per_group;
  }

  bool small_filter = (!n_channel_specialization) &&
      (conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16);

  // Fix host side helper params
  int sign = (conv_params.flip ? -1 : 1);
  int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
  int ijh = conv_params.in_strides[1] * conv_params.kdil[0];

  int inp_jump_w = sign * ijw;
  int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
  int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
      sign * (conv_params.wS[1] - 1) * ijw;

  // Build implicit gemm params
  ImplicitGemmConv2DParams gemm_params{
      /* const int M = */ implicit_M,
      /* const int N = */ implicit_N,
      /* const int K = */ implicit_K,

      /* const int gemm_k_iterations = */ gemm_k_iters,

      /* const int inp_jump_w = */ inp_jump_w,
      /* const int inp_jump_h = */ inp_jump_h,
      /* const int inp_jump_c = */ inp_jump_c,

      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int swizzle_log = */ swizzle_log};

  // Determine kernel
  std::string kname;
  kname.reserve(64);
  concatenate(
      kname,
      "implicit_gemm_conv_2d_",
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn,
      "_channel_",
      n_channel_specialization ? std::to_string(n_channel_specialization)
                               : "l",
      "_filter_",
      small_filter ? 's' : 'l');

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_conv_kernel(
      d,
      kname,
      out,
      bm,
      bn,
      bk,
      wm,
      wn,
      n_channel_specialization,
      small_filter);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Deduce grid launch dimensions
  int tile = 1 << swizzle_log;
  size_t grid_dim_y = (tm + tile - 1) / tile;
  size_t grid_dim_x = tn * tile;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);

  // Encode arrays
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_input_array(wt, 1);
  compute_encoder.set_output_array(out, 2);

  // Encode params
  compute_encoder.set_bytes(conv_params, 3);
  compute_encoder.set_bytes(gemm_params, 4);

  // Launch kernel
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// General 2-D implicit-GEMM convolution that also handles input dilation
// (idil) and arbitrary strides via lcm-based jumps. It uses the full C and O
// (no per-group split), so presumably it is only reached with groups == 1.
//
// Output positions are split into f_out_jump_h * f_out_jump_w residue
// classes, one per grid z slice. For each class, base_h / base_w hold the
// first filter tap that lands on a non-dilated input row/column and how many
// taps (stepping by f_wgt_jump) are visited. align_C (C divisible by bk) is
// passed as Metal function constant 200 and folded into the pipeline hash
// name.
void implicit_gemm_conv_2D_general_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array& out,
    const MLXConvParams<2>& conv_params) {
  // Deduce implicit gemm size
  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
  int implicit_N = conv_params.O;
  int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;

  // Determine block and warp tiles
  int wm = 2, wn = 2;

  // Make jump params
  int f_wgt_jump_h =
      std::lcm(conv_params.idil[0], conv_params.kdil[0]) / conv_params.kdil[0];
  int f_wgt_jump_w =
      std::lcm(conv_params.idil[1], conv_params.kdil[1]) / conv_params.kdil[1];

  int f_out_jump_h =
      std::lcm(conv_params.idil[0], conv_params.str[0]) / conv_params.str[0];
  int f_out_jump_w =
      std::lcm(conv_params.idil[1], conv_params.str[1]) / conv_params.str[1];

  int adj_out_h = (conv_params.oS[0] + f_out_jump_h - 1) / f_out_jump_h;
  int adj_out_w = (conv_params.oS[1] + f_out_jump_w - 1) / f_out_jump_w;
  int adj_out_hw = adj_out_h * adj_out_w;
  int adj_implicit_m = conv_params.N * adj_out_hw;

  Conv2DGeneralJumpParams jump_params{
      /* const int f_wgt_jump_h = */ f_wgt_jump_h,
      /* const int f_wgt_jump_w = */ f_wgt_jump_w,

      /* const int f_out_jump_h = */ f_out_jump_h,
      /* const int f_out_jump_w = */ f_out_jump_w,

      /* const int adj_out_h = */ adj_out_h,
      /* const int adj_out_w = */ adj_out_w,
      /* const int adj_out_hw = */ adj_out_hw,
      /* const int adj_implicit_m = */ adj_implicit_m};

  // Make base info
  std::vector base_h(f_out_jump_h);
  std::vector base_w(f_out_jump_w);

  int jump_h = conv_params.flip ? -conv_params.kdil[0] : conv_params.kdil[0];
  int jump_w = conv_params.flip ? -conv_params.kdil[1] : conv_params.kdil[1];

  int init_h =
      (conv_params.flip ? (conv_params.wS[0] - 1) * conv_params.kdil[0] : 0);
  int init_w =
      (conv_params.flip ? (conv_params.wS[1] - 1) * conv_params.kdil[1] : 0);

  // For each row residue class, skip filter taps until one hits an input row
  // that is a multiple of idil (i.e. not a dilation hole).
  for (int i = 0; i < f_out_jump_h; ++i) {
    int ih_loop = i * conv_params.str[0] - conv_params.pad[0] + init_h;

    int wh_base = 0;
    while (wh_base < conv_params.wS[0] && ih_loop % conv_params.idil[0] != 0) {
      wh_base++;
      ih_loop += jump_h;
    }

    int wh_size =
        ((conv_params.wS[0] - wh_base) + f_wgt_jump_h - 1) / f_wgt_jump_h;
    base_h[i] = {wh_base, wh_size};
  }

  // Same for columns.
  for (int j = 0; j < f_out_jump_w; ++j) {
    int iw_loop = j * conv_params.str[1] - conv_params.pad[1] + init_w;

    int ww_base = 0;
    while (ww_base < conv_params.wS[1] && iw_loop % conv_params.idil[1] != 0) {
      ww_base++;
      iw_loop += jump_w;
    }

    int ww_size =
        ((conv_params.wS[1] - ww_base) + f_wgt_jump_w - 1) / f_wgt_jump_w;
    base_w[j] = {ww_base, ww_size};
  }

  // Collect block sizes
  // NOTE(review): bn uses `bm == 64 && implicit_N >= 64` here but `||` in
  // implicit_gemm_conv_2D_gpu; confirm the difference is intentional (it may
  // reflect which general-kernel instantiations exist).
  int bm = adj_implicit_m >= 8192 && conv_params.C >= 64 ? 64 : 32;
  int bn = (bm == 64 && implicit_N >= 64) ? 64 : 32;
  int bk = 16;

  int tn = (implicit_N + bn - 1) / bn;
  int tm = (adj_implicit_m + bm - 1) / bm;
  int swizzle_log = 0;

  // Get channel iteration info
  int channel_k_iters = ((conv_params.C + bk - 1) / bk);
  int gemm_k_iters = channel_k_iters;
  bool align_C = conv_params.C % bk == 0;

  // Fix host side helper params
  int sign = (conv_params.flip ? -1 : 1);
  int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
  int ijh = conv_params.in_strides[1] * conv_params.kdil[0];

  int inp_jump_w = sign * ijw;
  int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
  int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
      sign * (conv_params.wS[1] - 1) * ijw;

  // Build implicit gemm params
  ImplicitGemmConv2DParams gemm_params{
      /* const int M = */ implicit_M,
      /* const int N = */ implicit_N,
      /* const int K = */ implicit_K,

      /* const int gemm_k_iterations = */ gemm_k_iters,

      /* const int inp_jump_w = */ inp_jump_w,
      /* const int inp_jump_h = */ inp_jump_h,
      /* const int inp_jump_c = */ inp_jump_c,

      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int swizzle_log = */ swizzle_log};

  // Determine kernel
  std::string kname;
  kname.reserve(64);
  concatenate(
      kname,
      "implicit_gemm_conv_2d_general_",
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn);
  std::string hash_name;
  hash_name.reserve(64);
  concatenate(hash_name, kname, "_alC_", align_C);
  metal::MTLFCList func_consts = {
      {&align_C, MTL::DataType::DataTypeBool, 200},
  };

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_conv_general_kernel(
      d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Deduce grid launch dimensions
  int tile = 1 << swizzle_log;
  size_t grid_dim_y = (tm + tile - 1) / tile;
  size_t grid_dim_x = tn * tile;
  size_t grid_dim_z = f_out_jump_h * f_out_jump_w;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);

  // Encode arrays
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_input_array(wt, 1);
  compute_encoder.set_output_array(out, 2);

  // Encode params
  compute_encoder.set_bytes(conv_params, 3);
  compute_encoder.set_bytes(gemm_params, 4);
  compute_encoder.set_bytes(jump_params, 5);

  compute_encoder.set_vector_bytes(base_h, 6);
  compute_encoder.set_vector_bytes(base_w, 7);

  // Launch kernel
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// 3-D convolution as an implicit GEMM using the steel 3-D conv kernels: per
// group, M = N * oD * oH * oW, N = O_per_group and
// K = wD * wH * wW * C_per_group. Tile selection mirrors
// implicit_gemm_conv_2D_gpu, without the small-channel specialization;
// `small_filter` marks filters no larger than 16 in every spatial dim.
// dispatch_conv_3D_gpu only calls it with idil == 1 and channel counts
// aligned (or padded) to 16. Launch: threadgroups of (32, wn, wm) threads over
// a grid of (tn, tm, groups) threadgroups.
void implicit_gemm_conv_3D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array& out,
    const MLXConvParams<3>& conv_params) {
  const int groups = conv_params.groups;
  const int C_per_group = conv_params.C / conv_params.groups;
  const int O_per_group = conv_params.O / conv_params.groups;

  // Deduce implicit gemm size
  const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1] *
      conv_params.oS[2];
  const int implicit_N = O_per_group;
  const int implicit_K = conv_params.wS[0] * conv_params.wS[1] *
      conv_params.wS[2] * C_per_group;

  // Determine block and warp tiles
  int wm = 2, wn = 2;

  int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
  int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
  int bk = 16;

  if (implicit_N <= 16) {
    bn = 8;
    wm = 4;
    wn = 1;
  }

  int tn = (implicit_N + bn - 1) / bn;
  int tm = (implicit_M + bm - 1) / bm;
  int swizzle_log = 0;

  bool small_filter =
      (conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16 &&
       conv_params.wS[2] <= 16);

  int channel_k_iters = ((C_per_group + bk - 1) / bk);
  int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * conv_params.wS[2] *
      channel_k_iters;

  // Fix host side helper params
  int sign = (conv_params.flip ? -1 : 1);
  int ijw = conv_params.in_strides[3] * conv_params.kdil[2];
  int ijh = conv_params.in_strides[2] * conv_params.kdil[1];
  int ijd = conv_params.in_strides[1] * conv_params.kdil[0];

  int inp_jump_w = sign * ijw;
  int inp_jump_h = sign * (ijh - (conv_params.wS[2] - 1) * ijw);
  int inp_jump_d = sign *
      (ijd - (conv_params.wS[1] - 1) * ijh - (conv_params.wS[2] - 1) * ijw);
  int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijd -
      sign * (conv_params.wS[1] - 1) * ijh -
      sign * (conv_params.wS[2] - 1) * ijw;

  // Build implicit gemm params
  ImplicitGemmConv3DParams gemm_params{
      /* const int M = */ implicit_M,
      /* const int N = */ implicit_N,
      /* const int K = */ implicit_K,

      /* const int gemm_k_iterations = */ gemm_k_iters,

      /* const int inp_jump_w = */ inp_jump_w,
      /* const int inp_jump_h = */ inp_jump_h,
      /* const int inp_jump_d = */ inp_jump_d,
      /* const int inp_jump_c = */ inp_jump_c,

      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int swizzle_log = */ swizzle_log};

  // Determine kernel
  std::string kname;
  kname.reserve(64);
  concatenate(
      kname,
      "implicit_gemm_conv_3d_",
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn,
      "_filter_",
      small_filter ? 's' : 'l');

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel =
      get_steel_conv_3d_kernel(d, kname, out, bm, bn, bk, wm, wn, small_filter);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Deduce grid launch dimensions
  int tile = 1 << swizzle_log;
  size_t grid_dim_y = (tm + tile - 1) / tile;
  size_t grid_dim_x = tn * tile;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);

  // Encode arrays
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_input_array(wt, 1);
  compute_encoder.set_output_array(out, 2);

  // Encode params
  compute_encoder.set_bytes(conv_params, 3);
  compute_encoder.set_bytes(gemm_params, 4);

  // Launch kernel
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Runs the 3-D implicit GEMM when channel counts are not multiples of 16.
// C (last axis of input and weights) and O (first axis of the weights) are
// zero-padded up to the next multiple of 16. The convolution writes into an
// intermediate whose last axis is O + extra_o. `out` then becomes a view of
// the intermediate's buffer using the intermediate's strides, so the padded
// output channels are skipped without a copy. This function sets `out`'s
// storage itself.
void pad_and_slice_conv_3D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in_pre,
    const array& wt_pre,
    array& out,
    const MLXConvParams<3>& conv_params) {
  // For now assume conv_params.groups == 1
  int extra_c = ((conv_params.C + 15) / 16) * 16 - conv_params.C;
  int extra_o = ((conv_params.O + 15) / 16) * 16 - conv_params.O;

  // Pad function
  // Grows the first axis by pad_ax_first and the last axis by pad_ax_last,
  // filling with zero (presumably appended at the high end, given the zero
  // low-pad argument); with no padding it only ensures row-contiguity.
  auto pad_array = [&](const array& x, int pad_ax_first, int pad_ax_last) {
    if (pad_ax_first == 0 && pad_ax_last == 0) {
      return ensure_row_contiguous(x, d, s);
    }

    auto xshape = x.shape();
    xshape.front() += pad_ax_first;
    xshape.back() += pad_ax_last;
    array x_copy(xshape, x.dtype(), nullptr, {});
    array zero(0, x.dtype());
    pad_gpu(x, zero, x_copy, {0, -1}, {0, 0}, s);
    d.add_temporary(x_copy, s.index);
    return x_copy;
  };

  // Allocate space for the intermediate output. Don't save it as a temporary
  // since it will be sliced to the output so they share the buffer.
  auto oshape = out.shape();
  oshape.back() += extra_o;
  array intermediate(oshape, out.dtype(), nullptr, {});
  intermediate.set_data(allocator::malloc(intermediate.nbytes()));

  // Actually pad and conv
  array in = pad_array(in_pre, 0, extra_c);
  array wt = pad_array(wt_pre, extra_o, extra_c);
  auto new_params =
      MLXConvParams<3>::with_padded_channels(conv_params, extra_o, extra_c);
  implicit_gemm_conv_3D_gpu(s, d, in, wt, intermediate, new_params);

  // Slice out
  out.copy_shared_buffer(
      intermediate, intermediate.strides(), {0}, intermediate.data_size());
}

// Chooses the 3-D convolution strategy:
//   1. idil == 1, channels not aligned to 16, groups == 1
//      -> pad_and_slice_conv_3D_gpu (sets up `out` itself);
//   2. idil == 1, channels aligned -> implicit GEMM;
//   3. otherwise -> explicit unfold + GEMM (grouped variant when groups > 1).
// "Aligned" means C_per_group % 16 == 0 and O_per_group is <= 16 or a
// multiple of 16. `copies` is not used by this function.
void dispatch_conv_3D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in_pre,
    const array& wt_pre,
    array& out,
    const MLXConvParams<3>& conv_params,
    std::vector& copies) {
  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1 &&
      conv_params.idil[2] == 1;
  const int C_per_group = conv_params.C / conv_params.groups;
  const int O_per_group = conv_params.O / conv_params.groups;
  bool mod16_channels = C_per_group % 16 == 0 &&
      (O_per_group <= 16 || O_per_group % 16 == 0);

  // Check if we can do implicit gemm but the channels are not divisible by 16
  // so we can pad and slice.
  //
  // We check it first because it doesn't need contiguous inputs and it needs
  // different output allocation.
  if (is_idil_one && !mod16_channels && conv_params.groups == 1) {
    return pad_and_slice_conv_3D_gpu(s, d, in_pre, wt_pre, out, conv_params);
  }

  // Allocate the output and ensure contiguous inputs
  out.set_data(allocator::malloc(out.nbytes()));
  auto in = ensure_row_contiguous(in_pre, d, s);
  auto wt = ensure_row_contiguous(wt_pre, d, s);

  // Perform the implicit gemm
  if (is_idil_one && mod16_channels) {
    return implicit_gemm_conv_3D_gpu(s, d, in, wt, out, conv_params);
  }

  // Explicit gemms where we unfold and do a matmul
  // (separate one for groups > 1)
  if (conv_params.groups > 1) {
    return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
  }
  return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}

void winograd_conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array& out,
    const MLXConvParams<2>& conv_params,
    std::vector& copies_w) {
  Shape padded_shape = {
      conv_params.N,
      conv_params.iS[0] + 2 * conv_params.pad[0],
      conv_params.iS[1] + 2 * conv_params.pad[1],
      conv_params.C};

  padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
  padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;

  array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});

  // Fill with zeros
  array zero_arr = array(0, in.dtype());
  fill_gpu(zero_arr, in_padded, s);
  copies_w.push_back(zero_arr);

  // Pick input slice from padded
  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
      conv_params.pad[1] * in_padded.strides()[2];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);

  // Copy input values into the slice
  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);

  copies_w.push_back(in_padded_slice);
  copies_w.push_back(in_padded);

  MLXConvParams<2> conv_params_updated{
      /* const int N = */ static_cast(in_padded.shape(0)),
      /* const int C = */ static_cast(in_padded.shape(3)),
      /*
const int O = */ static_cast(wt.shape(0)), /* const int iS[NDIM] = */ {static_cast(in_padded.shape(1)), static_cast(in_padded.shape(2))}, /* const int wS[NDIM] = */ {static_cast(wt.shape(1)), static_cast(wt.shape(2))}, /* const int oS[NDIM] = */ {static_cast(out.shape(1)), static_cast(out.shape(2))}, /* const int str[NDIM] = */ {1, 1}, /* const int pad[NDIM] = */ {0, 0}, /* const int kdil[NDIM] = */ {1, 1}, /* const int idil[NDIM] = */ {1, 1}, /* const size_t in_strides[NDIM + 2] = */ {in_padded.strides()[0], in_padded.strides()[1], in_padded.strides()[2], in_padded.strides()[3]}, /* const size_t wt_strides[NDIM + 2] = */ {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]}, /* const size_t out_strides[NDIM + 2] = */ {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]}, /* const int groups = */ 1, /* const bool flip = */ false, }; int O_c = conv_params.O; int C_c = conv_params.C; int N_tiles_n = conv_params.N; int N_tiles_h = (conv_params.oS[0] + 5) / 6; int N_tiles_w = (conv_params.oS[1] + 5) / 6; int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w; // Do filter transform Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O}; array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {}); filt_wg.set_data(allocator::malloc(filt_wg.nbytes())); copies_w.push_back(filt_wg); { int bc = 32; int bo = 4; std::string kname; kname.reserve(32); concatenate( kname, "winograd_conv_2d_weight_transform_", type_to_name(out), "_bc", bc); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(wt, 0); compute_encoder.set_output_array(filt_wg, 1); compute_encoder.set_bytes(C_c, 2); compute_encoder.set_bytes(O_c, 3); MTL::Size group_dims = MTL::Size(32, bo, 1); MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } // Do input transform Shape inp_wg_shape = {8 * 8, 
N_tiles, conv_params.C}; array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {}); inp_wg.set_data(allocator::malloc(inp_wg.nbytes())); copies_w.push_back(inp_wg); { int bc = 32; int wm = 2; int wn = 2; std::string kname; kname.reserve(32); concatenate( kname, "winograd_conv_2d_input_transform_", type_to_name(out), "_bc", bc); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in_padded, 0); compute_encoder.set_output_array(inp_wg, 1); compute_encoder.set_bytes(conv_params_updated, 2); MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } // Do batched gemm Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O}; array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {}); out_wg.set_data(allocator::malloc(out_wg.nbytes())); copies_w.push_back(out_wg); { std::vector empty_copies; steel_matmul( s, d, /*a = */ inp_wg, /*b = */ filt_wg, /*c = */ out_wg, /*M = */ N_tiles, /*N = */ conv_params.O, /*K = */ conv_params.C, /*batch_size_out = */ 8 * 8, /*a_cols = */ conv_params.C, /*b_cols = */ conv_params.O, /*a_transposed = */ false, /*b_transposed = */ false, /*copies = */ empty_copies); } // Do output transform { int bc = 32; int wm = 2; int wn = 2; std::string kname; kname.reserve(32); concatenate( kname, "winograd_conv_2d_output_transform_", type_to_name(out), "_bo", bc); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(out_wg, 0); compute_encoder.set_output_array(out, 1); compute_encoder.set_bytes(conv_params_updated, 2); MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } } 
void depthwise_conv_2D_gpu( const Stream& s, metal::Device& d, const array& in, const array& wt, array& out, const MLXConvParams<2>& conv_params) { std::string base_name; base_name.reserve(32); concatenate(base_name, "depthwise_conv_2d_", type_to_name(out)); const int N = conv_params.N; const int ker_h = conv_params.wS[0]; const int ker_w = conv_params.wS[1]; const int str_h = conv_params.str[0]; const int str_w = conv_params.str[1]; const int tc = 8; const int tw = 8; const int th = 4; const bool do_flip = conv_params.flip; metal::MTLFCList func_consts = { {&ker_h, MTL::DataType::DataTypeInt, 00}, {&ker_w, MTL::DataType::DataTypeInt, 01}, {&str_h, MTL::DataType::DataTypeInt, 10}, {&str_w, MTL::DataType::DataTypeInt, 11}, {&th, MTL::DataType::DataTypeInt, 100}, {&tw, MTL::DataType::DataTypeInt, 101}, {&do_flip, MTL::DataType::DataTypeBool, 200}, }; // clang-format off std::string hash_name; hash_name.reserve(64); concatenate( hash_name, base_name, "_ker_h_", ker_h, "_ker_w_", ker_w, "_str_h_", str_h, "_str_w_", str_w, "_tgp_h_", th, "_tgp_w_", tw, "_do_flip_", do_flip ? 
't' : 'n'); // clang-format on auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(base_name, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(wt, 1); compute_encoder.set_output_array(out, 2); compute_encoder.set_bytes(conv_params, 3); MTL::Size group_dims = MTL::Size(tc, tw, th); MTL::Size grid_dims = MTL::Size( conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void dispatch_conv_2D_gpu( const Stream& s, metal::Device& d, const array& in, const array& wt, array& out, const MLXConvParams<2>& conv_params, std::vector& copies) { bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1; bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; if (is_idil_one && conv_params.groups > 1) { const int C_per_group = conv_params.C / conv_params.groups; const int O_per_group = conv_params.O / conv_params.groups; if (C_per_group == 1 && O_per_group == 1 && is_kdil_one && conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 && conv_params.str[0] <= 2 && conv_params.str[1] <= 2 && conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 && conv_params.wt_strides[1] == conv_params.wS[1] && conv_params.C % 16 == 0 && conv_params.C == conv_params.O) { return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params); } if ((C_per_group <= 4 || C_per_group % 16 == 0) && (O_per_group <= 16 || O_per_group % 16 == 0)) { return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); } else { return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); } } // Direct to winograd conv bool inp_large = (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096; bool channels_large = (conv_params.C + conv_params.O) >= 256; bool out_large = (conv_params.N * 
conv_params.oS[0] * conv_params.oS[1]) >= 256; if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one && conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && channels_large) { return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); } // Direct to implicit gemm conv if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) && (conv_params.O <= 16 || conv_params.O % 16 == 0)) { return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); } else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) { return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params); } // Direct to explicit gemm conv else { return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); } } void depthwise_conv_1D_gpu( const Stream& s, metal::Device& d, const array& in, const array& wt, array& out) { bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX; std::string base_name; base_name.reserve(32); concatenate( base_name, "depthwise_conv_1d_", large ? 
"_large" : "", type_to_name(out)); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(base_name); compute_encoder.set_compute_pipeline_state(kernel); auto B = in.shape(0); auto Tout = out.shape(1); auto D = in.shape(2); auto K = wt.shape(1); compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(wt, 1); compute_encoder.set_output_array(out, 2); if (large) { int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)}; compute_encoder.set_bytes(strides, 3, 3); } else { int strides[3] = { static_cast(in.strides(0)), static_cast(in.strides(1)), static_cast(in.strides(2))}; compute_encoder.set_bytes(strides, 3, 3); } compute_encoder.set_bytes(K, 4); auto group_dims = get_block_dims(D, Tout, B); MTL::Size grid_dims = MTL::Size(D, Tout, B); compute_encoder.dispatch_threads(grid_dims, group_dims); } void conv_1D_gpu( const Stream& s, metal::Device& d, const array& in_pre, const array& wt_pre, array& out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, int groups, bool flip, std::vector& copies) { // Allocate space and ensure weights are contiguous out.set_data(allocator::malloc(out.nbytes())); auto in = ensure_row_contiguous(in_pre, d, s); auto wt = ensure_row_contiguous(wt_pre, d, s); bool is_idil_one = in_dilation[0] == 1; int C = in.shape(2); int O = wt.shape(0); // Fast path for fully separable 1D convolution if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 && wt_dilation[0] == 1 && padding[0] == 0 && !flip) { depthwise_conv_1D_gpu(s, d, in, wt, out); return; } const int C_per_group = C / groups; const int O_per_group = O / groups; // Direct to implicit gemm conv if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && (O_per_group <= 16 || O_per_group % 16 == 0)) { MLXConvParams<2> conv_params{ /* const int N = */ static_cast(in.shape(0)), /* const int C = */ C, /* const int O = */ O, /* const int iS[NDIM] = 
*/ {static_cast(in.shape(1)), 1}, /* const int wS[NDIM] = */ {static_cast(wt.shape(1)), 1}, /* const int oS[NDIM] = */ {static_cast(out.shape(1)), 1}, /* const int str[NDIM] = */ {wt_strides[0], 1}, /* const int pad[NDIM] = */ {padding[0], 0}, /* const int kdil[NDIM] = */ {wt_dilation[0], 1}, /* const int idil[NDIM] = */ {in_dilation[0], 1}, /* const size_t in_strides[NDIM + 2] = */ {in.strides()[0], in.strides()[1], 0, in.strides()[2]}, /* const size_t wt_strides[NDIM + 2] = */ {wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]}, /* const size_t out_strides[NDIM + 2] = */ {out.strides()[0], out.strides()[1], 0, out.strides()[2]}, /* const int groups = */ groups, /* const bool flip = */ flip}; dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); return; } // Make conv params MLXConvParams<1> conv_params{ /* const int N = */ static_cast(in.shape(0)), /* const int C = */ static_cast(in.shape(2)), /* const int O = */ static_cast(wt.shape(0)), /* const int iS[NDIM] = */ {static_cast(in.shape(1))}, /* const int wS[NDIM] = */ {static_cast(wt.shape(1))}, /* const int oS[NDIM] = */ {static_cast(out.shape(1))}, /* const int str[NDIM] = */ {wt_strides[0]}, /* const int pad[NDIM] = */ {padding[0]}, /* const int kdil[NDIM] = */ {wt_dilation[0]}, /* const int idil[NDIM] = */ {in_dilation[0]}, /* const size_t in_strides[NDIM + 2] = */ {in.strides()[0], in.strides()[1], in.strides()[2]}, /* const size_t wt_strides[NDIM + 2] = */ {wt.strides()[0], wt.strides()[1], wt.strides()[2]}, /* const size_t out_strides[NDIM + 2] = */ {out.strides()[0], out.strides()[1], out.strides()[2]}, /* const int groups = */ groups, /* const bool flip = */ flip}; // Direct to explicit gemm conv if (groups > 1) { return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); } else { return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); } } void conv_2D_gpu( const Stream& s, metal::Device& d, const array& in_pre, const array& wt_pre, array& out, const std::vector& 
padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, const int groups, bool flip, std::vector& copies) { // Allocate space and ensure weights are contiguous out.set_data(allocator::malloc(out.nbytes())); auto in = ensure_row_contiguous(in_pre, d, s); auto wt = ensure_row_contiguous(wt_pre, d, s); // Make conv params MLXConvParams<2> conv_params{ /* const int N = */ static_cast(in.shape(0)), /* const int C = */ static_cast(in.shape(3)), /* const int O = */ static_cast(wt.shape(0)), /* const int iS[NDIM] = */ {static_cast(in.shape(1)), static_cast(in.shape(2))}, /* const int wS[NDIM] = */ {static_cast(wt.shape(1)), static_cast(wt.shape(2))}, /* const int oS[NDIM] = */ {static_cast(out.shape(1)), static_cast(out.shape(2))}, /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]}, /* const int pad[NDIM] = */ {padding[0], padding[1]}, /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]}, /* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]}, /* const size_t in_strides[NDIM + 2] = */ {in.strides(0), in.strides(1), in.strides(2), in.strides(3)}, /* const size_t wt_strides[NDIM + 2] = */ {wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)}, /* const size_t out_strides[NDIM + 2] = */ {out.strides(0), out.strides(1), out.strides(2), out.strides(3)}, /* const int groups = */ groups, /* const bool flip = */ flip, }; dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); } void conv_3D_gpu( const Stream& s, metal::Device& d, const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, int groups, bool flip, std::vector& copies) { // We will use the contiguous strides for the conv params because that is // what the rest of the code expects. 
constexpr int NDIM = 3; int64_t in_arr_strides[NDIM + 2]; int64_t wt_arr_strides[NDIM + 2]; in_arr_strides[NDIM + 1] = wt_arr_strides[NDIM + 1] = 1; for (int i = NDIM; i >= 0; i--) { in_arr_strides[i] = in_arr_strides[i + 1] * in.shape(i + 1); wt_arr_strides[i] = wt_arr_strides[i + 1] * wt.shape(i + 1); } // Make conv params MLXConvParams<3> conv_params{ /* const int N = */ static_cast(in.shape(0)), /* const int C = */ static_cast(in.shape(4)), /* const int O = */ static_cast(wt.shape(0)), /* const int iS[NDIM] = */ {static_cast(in.shape(1)), static_cast(in.shape(2)), static_cast(in.shape(3))}, /* const int wS[NDIM] = */ {static_cast(wt.shape(1)), static_cast(wt.shape(2)), static_cast(wt.shape(3))}, /* const int oS[NDIM] = */ {static_cast(out.shape(1)), static_cast(out.shape(2)), static_cast(out.shape(3))}, /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]}, /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]}, /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1], wt_dilation[2]}, /* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1], in_dilation[2]}, /* const size_t in_strides[NDIM + 2] = */ {in_arr_strides[0], in_arr_strides[1], in_arr_strides[2], in_arr_strides[3], in_arr_strides[4]}, /* const size_t wt_strides[NDIM + 2] = */ {wt_arr_strides[0], wt_arr_strides[1], wt_arr_strides[2], wt_arr_strides[3], wt_arr_strides[4]}, /* const size_t out_strides[NDIM + 2] = */ {out.strides(0), out.strides(1), out.strides(2), out.strides(3), out.strides(4)}, /* const int groups = */ groups, /* const bool flip = */ flip, }; return dispatch_conv_3D_gpu(s, d, in, wt, out, conv_params, copies); } } // namespace void Convolution::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); // Intermediates that are put here will be added to the command encoder as // temporaries. 
std::vector copies; // Some shortcuts for brevity const array& in = inputs[0]; const array& wt = inputs[1]; // 3D conv if (out.ndim() == 5) { conv_3D_gpu( s, d, in, wt, out, padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, groups_, flip_, copies); } // 2D conv else if (out.ndim() == 4) { conv_2D_gpu( s, d, in, wt, out, padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, groups_, flip_, copies); } // 1D conv else if (out.ndim() == 3) { conv_1D_gpu( s, d, in, wt, out, padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, groups_, flip_, copies); } // Throw error else { throw std::invalid_argument( "[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions."); } // Record copies if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index); } } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/copy.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include "mlx/backend/gpu/copy.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" namespace mlx::core { constexpr int MAX_COPY_SPECIALIZED_DIMS = 3; void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { bool donated = set_copy_output_data(in, out, ctype); if (donated && in.dtype() == out.dtype()) { // If the output has the same type as the input then there is nothing to // copy, just use the buffer. 
return; } if (ctype == CopyType::GeneralGeneral) { ctype = CopyType::General; } copy_gpu_inplace(in, out, ctype, s); } void copy_gpu_inplace( const array& in, array& out, const Shape& data_shape, const Strides& strides_in_pre, const Strides& strides_out_pre, int64_t inp_offset, int64_t out_offset, CopyType ctype, const Stream& s, std::optional dynamic_i_offset /* = std::nullopt */, std::optional dynamic_o_offset /* = std::nullopt */) { if (out.size() == 0) { return; } // Try to collapse contiguous dims auto maybe_collapse = [ctype, &data_shape, &strides_in_pre, &strides_out_pre]() { if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { auto [shape, strides] = collapse_contiguous_dims( data_shape, std::vector{strides_in_pre, strides_out_pre}, /* size_cap = */ INT32_MAX); return std::make_tuple(shape, strides[0], strides[1]); } else { Strides e{}; return std::make_tuple(Shape{}, e, e); } }; auto [shape, strides_in_, strides_out_] = maybe_collapse(); int ndim = shape.size(); bool large; if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { // Allow for negative strides large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; } else { large = out.data_size() > UINT32_MAX; } bool dynamic = dynamic_i_offset || dynamic_o_offset; auto& d = metal::device(s.device); int work_per_thread = 1; std::string kernel_name; switch (ctype) { case CopyType::Scalar: kernel_name = large ? "s2" : "s"; break; case CopyType::Vector: kernel_name = large ? "v2" : "v"; break; case CopyType::General: kernel_name = "g"; break; case CopyType::GeneralGeneral: kernel_name = "gg"; break; } if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) { kernel_name += std::to_string(shape.size()); } else { work_per_thread = large ? 
4 : 2; concatenate(kernel_name, "n", std::to_string(work_per_thread)); } if (large) { kernel_name += "large"; } if (dynamic) { kernel_name += "_dynamic"; if (ctype != CopyType::GeneralGeneral) { throw std::runtime_error( "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy"); } } } else { work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); if (!large && work_per_thread > 1) { kernel_name += "n"; } } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) : get_copy_kernel(d, kernel_name, in, out); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); inp_offset *= size_of(in.dtype()); out_offset *= size_of(out.dtype()); compute_encoder.set_input_array(in, 0, inp_offset); compute_encoder.set_output_array(out, 1, out_offset); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { Strides strides_in{strides_in_.begin(), strides_in_.end()}; Strides strides_out{strides_out_.begin(), strides_out_.end()}; if (ndim > 3) { compute_encoder.set_vector_bytes(shape, ndim, 2); } compute_encoder.set_vector_bytes(strides_in, ndim, 3); if (ctype == CopyType::GeneralGeneral) { compute_encoder.set_vector_bytes(strides_out, ndim, 4); } size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1; size_t data_size = 1; for (auto& s : shape) data_size *= s; size_t rest = data_size / (dim0 * dim1); if (ndim > MAX_COPY_SPECIALIZED_DIMS) { compute_encoder.set_bytes(ndim, 5); dim0 = (dim0 + work_per_thread - 1) / work_per_thread; } if (dynamic) { if (dynamic_i_offset) { compute_encoder.set_input_array(*dynamic_i_offset, 6); } else { compute_encoder.set_bytes(0ll, 6); } if (dynamic_o_offset) { compute_encoder.set_input_array(*dynamic_o_offset, 7); } else { compute_encoder.set_bytes(0ll, 7); } } // NB assuming thread_group_size is a power of 2 larger than 32 x 32 if (thread_group_size != 1024) { throw std::runtime_error("[Metal::copy] Must use 1024 sized block"); } auto group_dims = get_block_dims(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims; if (large) { compute_encoder.set_bytes(out.data_size(), 2); grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); } else { compute_encoder.set_bytes(out.data_size(), 2); grid_dims = MTL::Size(nthreads, 1, 1); } compute_encoder.dispatch_threads(grid_dims, group_dims); } } void fill_gpu(const array& val, array& out, const Stream& s) { if (out.size() == 0) { return; } out.set_data(allocator::malloc(out.nbytes())); bool large = out.data_size() > UINT32_MAX; int work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); auto& d = metal::device(s.device); std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? 
"sn" : "s"); concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out)); auto kernel = get_copy_kernel(d, kernel_name, val, out); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(val, 0); compute_encoder.set_output_array(out, 1); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims; if (large) { compute_encoder.set_bytes(out.data_size(), 2); grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); } else { compute_encoder.set_bytes(out.data_size(), 2); grid_dims = MTL::Size(nthreads, 1, 1); } compute_encoder.dispatch_threads(grid_dims, group_dims); } void reshape_gpu(const array& in, array& out, Stream s) { auto [copy_necessary, out_strides] = prepare_reshape(in, out); if (copy_necessary) { out.set_data(allocator::malloc(out.nbytes())); copy_gpu_inplace( in, out, in.shape(), in.strides(), make_contiguous_strides(in.shape()), 0, 0, CopyType::General, s); } else { shared_buffer_reshape(in, out_strides, out); } } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/custom_kernel.cpp ================================================ // Copyright © 2024 Apple Inc. 
#include #include #include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast.h" #include "mlx/fast_primitives.h" #include "mlx/utils.h" namespace mlx::core::fast { struct CustomKernelCache { std::unordered_map libraries; }; static CustomKernelCache& cache() { static CustomKernelCache cache_; return cache_; }; std::string write_signature( std::string func_name, const std::string& header, const std::string& source, const std::vector& input_names, const std::vector& inputs, const std::vector& output_names, const std::vector& output_dtypes, const std::vector>& template_args, const std::vector& attributes, const std::vector>& shape_infos, bool atomic_outputs) { std::string kernel_source; kernel_source.reserve(header.size() + source.size() + 16384); kernel_source += header; // Auto-generate a function signature based on `template_args` // and the dtype/shape of the arrays passed as `inputs`. if (!template_args.empty()) { kernel_source += "template <"; int i = 0; for (const auto& [name, arg] : template_args) { std::string param_type; if (std::holds_alternative(arg)) { param_type = "int"; } else if (std::holds_alternative(arg)) { param_type = "bool"; } else if (std::holds_alternative(arg)) { param_type = "typename"; } if (i > 0) { kernel_source += ", "; } kernel_source += param_type; kernel_source += " "; kernel_source += name; i++; } kernel_source += ">\n"; } kernel_source += "[[kernel]] void "; kernel_source += func_name; kernel_source += "(\n"; int index = 0; constexpr int max_constant_array_size = 8; // Add inputs for (int i = 0; i < inputs.size(); ++i) { const auto& name = input_names[i]; const auto& arr = inputs[i]; auto dtype = get_type_string(arr.dtype()); std::string location = arr.size() < max_constant_array_size ? "constant" : "device"; std::string ref = arr.ndim() == 0 ? 
"&" : "*"; kernel_source += " const "; kernel_source += location; kernel_source += " "; kernel_source += dtype; kernel_source += ref; kernel_source += " "; kernel_source += name; kernel_source += " [[buffer("; kernel_source += std::to_string(index); kernel_source += ")]],\n"; index++; // Add input shape, strides and ndim if present in the source if (arr.ndim() > 0) { if (std::get<0>(shape_infos[i])) { kernel_source += (" const constant int* " + name + "_shape [[buffer(" + std::to_string(index) + ")]],\n"); index++; } if (std::get<1>(shape_infos[i])) { kernel_source += (" const constant int64_t* " + name + "_strides [[buffer(" + std::to_string(index) + ")]],\n"); index++; } if (std::get<2>(shape_infos[i])) { kernel_source += (" const constant int& " + name + "_ndim [[buffer(" + std::to_string(index) + ")]],\n"); index++; } } } // Add outputs for (int i = 0; i < output_names.size(); ++i) { const auto& name = output_names[i]; const auto& dtype = output_dtypes[i]; kernel_source += " device "; auto type_string = get_type_string(dtype); if (atomic_outputs) { kernel_source += "atomic<"; } kernel_source += type_string; if (atomic_outputs) { kernel_source += ">"; } kernel_source += "* "; kernel_source += name; kernel_source += " [[buffer("; kernel_source += std::to_string(index); kernel_source += ")]]"; if (index < inputs.size() + output_names.size() - 1 || attributes.size() > 0) { kernel_source += ",\n"; } else { kernel_source += ") {\n"; } index++; } index = 0; for (const auto& attr : attributes) { kernel_source += attr; if (index < attributes.size() - 1) { kernel_source += ",\n"; } else { kernel_source += ") {\n"; } index++; } kernel_source += source; kernel_source += "\n}\n"; return kernel_source; } std::string write_template( const std::vector>& template_args) { std::ostringstream template_def; template_def << "<"; int i = 0; for (const auto& [name, arg] : template_args) { if (i > 0) { template_def << ", "; } if (std::holds_alternative(arg)) { template_def << 
std::get(arg); } else if (std::holds_alternative(arg)) { template_def << std::get(arg); } else if (std::holds_alternative(arg)) { template_def << get_type_string(std::get(arg)); } i++; } template_def << ">"; return template_def.str(); } CustomKernelFunction metal_kernel( const std::string& name, const std::vector& input_names, const std::vector& output_names, const std::string& source, const std::string& header /* = "" */, bool ensure_row_contiguous /* = true */, bool atomic_outputs /* = false */) { if (output_names.empty()) { throw std::invalid_argument( "[metal_kernel] Must specify at least one output."); } std::vector> shape_infos; for (auto& n : input_names) { std::tuple shape_info; std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos; std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos; std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos; shape_infos.push_back(shape_info); } const std::vector> metal_attributes = { {"dispatch_quadgroups_per_threadgroup", "uint"}, {"dispatch_simdgroups_per_threadgroup", "uint"}, {"dispatch_threads_per_threadgroup", "uint3"}, {"grid_origin", "uint3"}, {"grid_size", "uint3"}, {"quadgroup_index_in_threadgroup", "uint"}, {"quadgroups_per_threadgroup", "uint"}, {"simdgroup_index_in_threadgroup", "uint"}, {"simdgroups_per_threadgroup", "uint"}, {"thread_execution_width", "uint"}, {"thread_index_in_quadgroup", "uint"}, {"thread_index_in_simdgroup", "uint"}, {"thread_index_in_threadgroup", "uint"}, {"thread_position_in_grid", "uint3"}, {"thread_position_in_threadgroup", "uint3"}, {"threadgroup_position_in_grid", "uint3"}, {"threadgroups_per_grid", "uint3"}, {"threads_per_grid", "uint3"}, {"threads_per_simdgroup", "uint"}, {"threads_per_threadgroup", "uint3"}, }; std::vector attributes; for (const auto& [attr, dtype] : metal_attributes) { if (source.find(attr) != std::string::npos) { attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]"); } } return [=, 
shape_infos = std::move(shape_infos), attributes = std::move(attributes)]( const std::vector& inputs, const std::vector& output_shapes, const std::vector& output_dtypes, std::tuple grid, std::tuple threadgroup, const std::vector>& template_args = {}, std::optional init_value = std::nullopt, bool verbose = false, StreamOrDevice s_ = {}) { if (inputs.size() != input_names.size()) { std::ostringstream msg; msg << "[metal_kernel] Expected `inputs` to have size " << input_names.size() << " but got size " << inputs.size() << "." << std::endl; throw std::invalid_argument(msg.str()); } if (output_shapes.size() != output_names.size()) { std::ostringstream msg; msg << "[metal_kernel] Expected `output_shapes` to have size " << output_names.size() << " but got size " << output_shapes.size() << "." << std::endl; throw std::invalid_argument(msg.str()); } if (output_dtypes.size() != output_names.size()) { std::ostringstream msg; msg << "[metal_kernel] Expected `output_dtypes` to have size " << output_names.size() << " but got size " << output_dtypes.size() << "." 
<< std::endl; throw std::invalid_argument(msg.str()); } auto s = to_stream(s_); if (s.device != Device::gpu) { throw std::invalid_argument("[metal_kernel] Only supports the GPU."); } std::string kernel_name = "custom_kernel_" + name; std::string template_def = ""; if (!template_args.empty()) { std::regex disallowed_chars("\\<|\\>|(, )"); template_def = write_template(template_args); auto template_hash = std::regex_replace(template_def, disallowed_chars, "_"); template_hash.pop_back(); kernel_name += "_"; kernel_name += template_hash; } std::string kernel_source = write_signature( kernel_name, header, source, input_names, inputs, output_names, output_dtypes, template_args, attributes, shape_infos, atomic_outputs); if (!template_args.empty()) { template_def = kernel_name + template_def; kernel_source += "\ntemplate [[host_name(\""; kernel_source += kernel_name; kernel_source += "\")]] [[kernel]] decltype("; kernel_source += template_def; kernel_source += ") "; kernel_source += template_def; kernel_source += ";\n"; } if (verbose) { std::cout << "Generated source code for `" << name << "`:" << std::endl << "```" << std::endl << kernel_source << std::endl << "```" << std::endl; } return array::make_arrays( std::move(output_shapes), std::move(output_dtypes), std::make_shared( s, std::move(kernel_name), std::move(kernel_source), grid, threadgroup, shape_infos, ensure_row_contiguous, init_value, std::vector{}, false, 0), std::move(inputs)); }; } void CustomKernel::eval_gpu( const std::vector& inputs, std::vector& outputs) { // silence some warnings (void)is_precompiled_; (void)shared_memory_; auto& s = stream(); std::vector copies; for (auto& out : outputs) { if (init_value_) { copies.emplace_back(init_value_.value(), out.dtype()); fill_gpu(copies.back(), out, s); } else { out.set_data(allocator::malloc(out.nbytes())); } } auto check_input = [&copies, &s, this](const array& x) -> const array { bool no_copy = x.flags().row_contiguous; if (!ensure_row_contiguous_ || no_copy) 
{ return x; } else { copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); copy_gpu(x, copies.back(), CopyType::General, s); return copies.back(); } }; std::vector checked_inputs; for (const array& in : inputs) { checked_inputs.push_back(check_input(in)); } auto& d = metal::device(s.device); { // Clear kernels from the device library cache if needed auto& kernel_cache = cache(); if (auto it = kernel_cache.libraries.find(name_); it != kernel_cache.libraries.end()) { if (it->second != source_) { auto& d = metal::device(s.device); d.clear_library(name_); it->second = source_; } } else { kernel_cache.libraries.emplace(name_, source_); } } auto lib = d.get_library(name_, [this] { return metal::utils() + source_; }); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int index = 0; for (int i = 0; i < checked_inputs.size(); i++) { const array& in = checked_inputs[i]; auto& shape_info = shape_infos_[i]; compute_encoder.set_input_array(in, index); index++; if (in.ndim() > 0) { int ndim = in.ndim(); if (std::get<0>(shape_info)) { compute_encoder.set_vector_bytes(in.shape(), ndim, index); index++; } if (std::get<1>(shape_info)) { compute_encoder.set_vector_bytes(in.strides(), ndim, index); index++; } if (std::get<2>(shape_info)) { compute_encoder.set_bytes(ndim, index); index++; } } } for (auto& out : outputs) { compute_encoder.set_output_array(out, index); index++; } const auto [tx, ty, tz] = threadgroup_; auto tg_size = tx * ty * tz; auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup(); if (tg_size > max_tg_size) { std::ostringstream msg; msg << "Thread group size (" << tg_size << ") is greater than " << " the maximum allowed threads per threadgroup (" << max_tg_size << ")."; throw std::invalid_argument(msg.str()); } const auto [gx, gy, gz] = grid_; MTL::Size group_dims = MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); MTL::Size grid_dims = 
MTL::Size(gx, gy, gz); compute_encoder.dispatch_threads(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core::fast ================================================ FILE: mlx/backend/metal/device.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include #include #define NS_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION #define MTL_PRIVATE_IMPLEMENTATION #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" namespace std { // Required for putting the pointer in unordered_set. template struct hash> { size_t operator()(const NS::SharedPtr& p) const { return std::hash{}(p.get()); } }; } // namespace std namespace mlx::core::metal { namespace { constexpr const char* default_mtllib_path = METAL_PATH; auto get_metal_version() { auto get_metal_version_ = []() { if (__builtin_available(macOS 26, iOS 26, tvOS 26, visionOS 26, *)) { return MTL::LanguageVersion4_0; } else if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) { return MTL::LanguageVersion3_2; } else { return MTL::LanguageVersion3_1; } }; static auto metal_version_ = get_metal_version_(); return metal_version_; } auto load_device() { auto devices = MTL::CopyAllDevices(); auto device = static_cast(devices->object(0)) ?: MTL::CreateSystemDefaultDevice(); if (!device) { throw std::runtime_error("Failed to load device"); } return device; } std::pair load_library_from_path( MTL::Device* device, const char* path) { auto library = NS::String::string(path, NS::UTF8StringEncoding); NS::Error* error; auto lib = device->newLibrary(library, &error); return std::make_pair(lib, error); } #ifdef SWIFTPM_BUNDLE MTL::Library* try_load_bundle( MTL::Device* device, NS::URL* url, const std::string& lib_name) { std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" + SWIFTPM_BUNDLE + 
".bundle"; auto bundle = NS::Bundle::alloc()->init( NS::String::string(bundle_path.c_str(), NS::UTF8StringEncoding)); if (bundle != nullptr) { std::string resource_path = std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" + lib_name + ".metallib"; auto [lib, error] = load_library_from_path(device, resource_path.c_str()); if (lib) { return lib; } } return nullptr; } MTL::Library* try_load_framework( MTL::Device* device, NS::URL* url, const std::string& lib_name) { std::string resource_path = std::string(url->fileSystemRepresentation()) + "/" + lib_name + ".metallib"; auto [lib, error] = load_library_from_path(device, resource_path.c_str()); if (lib) { return lib; } return nullptr; } #endif // Firstly, search for the metallib in the same path as this binary std::pair load_colocated_library( MTL::Device* device, const std::string& relative_path) { auto path = current_binary_dir() / relative_path; if (!path.has_extension()) { path.replace_extension(".metallib"); } return load_library_from_path(device, path.c_str()); } std::pair load_swiftpm_library( MTL::Device* device, const std::string& lib_name) { #ifdef SWIFTPM_BUNDLE MTL::Library* library = try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name); if (library != nullptr) { return {library, nullptr}; } auto bundles = NS::Bundle::allBundles(); for (int i = 0, c = (int)bundles->count(); i < c; i++) { auto bundle = reinterpret_cast(bundles->object(i)); library = try_load_bundle(device, bundle->resourceURL(), lib_name); if (library != nullptr) { return {library, nullptr}; } } // if SWIFTPM_BUNDLE is a framework identifier, try loading from that auto frameworks = NS::Bundle::allFrameworks(); for (int i = 0, c = (int)frameworks->count(); i < c; i++) { const auto bundle = reinterpret_cast(frameworks->object(i)); const auto identifier = bundle->bundleIdentifier(); if (identifier != nullptr && !strcmp(identifier->utf8String(), SWIFTPM_BUNDLE)) { library = try_load_framework(device, 
bundle->resourceURL(), lib_name); if (library != nullptr) { return {library, nullptr}; } } } #endif return {nullptr, nullptr}; } MTL::Library* load_default_library(MTL::Device* device) { NS::Error* error[5]; MTL::Library* lib; // First try the colocated mlx.metallib std::tie(lib, error[0]) = load_colocated_library(device, "mlx"); if (lib) { return lib; } std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx"); if (lib) { return lib; } // Then try default.metallib in a SwiftPM bundle if we have one std::tie(lib, error[2]) = load_swiftpm_library(device, "default"); if (lib) { return lib; } // Try lo load resources from Framework resources if SwiftPM wrapped as a // dynamic framework. std::tie(lib, error[3]) = load_colocated_library(device, "Resources/default"); if (lib) { return lib; } // Finally try default_mtllib_path std::tie(lib, error[4]) = load_library_from_path(device, default_mtllib_path); if (!lib) { std::ostringstream msg; msg << "Failed to load the default metallib. 
"; for (int i = 0; i < 5; i++) { if (error[i] != nullptr) { msg << error[i]->localizedDescription()->utf8String() << " "; } } throw std::runtime_error(msg.str()); } return lib; } MTL::Library* load_library( MTL::Device* device, const std::string& lib_name, const std::string& lib_path) { // We have been given a path that ends in metallib so try to load it if (lib_path.size() > 9 && std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) { auto [lib, error] = load_library_from_path(device, lib_path.c_str()); if (!lib) { std::ostringstream msg; msg << "Failed to load the metallib from <" << lib_path << "> with error " << error->localizedDescription()->utf8String(); throw std::runtime_error(msg.str()); } return lib; } // We have been given a path so try to load from lib_path / lib_name.metallib if (lib_path.size() > 0) { std::string full_path = lib_path + "/" + lib_name + ".metallib"; auto [lib, error] = load_library_from_path(device, full_path.c_str()); if (!lib) { std::ostringstream msg; msg << "Failed to load the metallib from <" << full_path << "> with error " << error->localizedDescription()->utf8String(); throw std::runtime_error(msg.str()); } return lib; } // Try to load the colocated library { auto [lib, error] = load_colocated_library(device, lib_name); if (lib) { return lib; } } // Try to load the library from swiftpm { auto [lib, error] = load_swiftpm_library(device, lib_name); if (lib) { return lib; } } std::ostringstream msg; msg << "Failed to load the metallib " << lib_name << ".metallib. 
" << "We attempted to load it from <" << current_binary_dir() << "/" << lib_name << ".metallib>"; #ifdef SWIFTPM_BUNDLE msg << " and from the Swift PM bundle."; #endif throw std::runtime_error(msg.str()); } } // namespace CommandEncoder::CommandEncoder( Device& d, int index, const MTL::ResidencySet* residency_set) : device_(d) { auto pool = new_scoped_memory_pool(); queue_ = NS::TransferPtr(device_.mtl_device()->newCommandQueue()); if (!queue_) { throw std::runtime_error( "[metal::CommandEncoder] Failed to make new command queue."); } if (residency_set) { queue_->addResidencySet(residency_set); } debug_set_stream_queue_label(queue_.get(), index); buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences()); } void CommandEncoder::set_buffer( const MTL::Buffer* buf, int idx, int64_t offset /* = 0 */) { // Record as both input and output to ensure synchronization between command // buffers all_inputs_.insert((void*)buf); all_outputs_.insert((void*)buf); get_command_encoder()->setBuffer(buf, offset, idx); } void CommandEncoder::set_input_array( const array& a, int idx, int64_t offset /* = 0 */) { if (all_inputs_.insert(a.buffer().ptr()).second) { buffer_sizes_ += a.data_size(); } auto r_buf = static_cast(const_cast(a.buffer().ptr())); needs_barrier_ = needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end()); auto a_buf = static_cast(a.buffer().ptr()); get_command_encoder()->setBuffer(a_buf, a.offset() + offset, idx); } void CommandEncoder::set_output_array( array& a, int idx, int64_t offset /* = 0 */) { // Add barriers before adding the output to the output set set_input_array(a, idx, offset); register_output_array(a); } void CommandEncoder::register_output_array(const array& a) { all_outputs_.insert(a.buffer().ptr()); auto buf = static_cast(const_cast(a.buffer().ptr())); if (concurrent_) { concurrent_outputs_.insert(buf); } else { next_outputs_.insert(buf); } } void CommandEncoder::add_temporary(array arr) { 
temporaries_.push_back(std::move(arr)); } void CommandEncoder::add_temporaries(std::vector arrays) { temporaries_.insert( temporaries_.end(), std::make_move_iterator(arrays.begin()), std::make_move_iterator(arrays.end())); } void CommandEncoder::maybeInsertBarrier() { if (needs_barrier_) { get_command_encoder()->memoryBarrier(MTL::BarrierScopeBuffers); needs_barrier_ = false; prev_outputs_ = std::move(next_outputs_); } else { prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end()); } next_outputs_.clear(); } void CommandEncoder::dispatch_threadgroups( MTL::Size grid_dims, MTL::Size group_dims) { maybeInsertBarrier(); buffer_ops_++; get_command_encoder()->dispatchThreadgroups(grid_dims, group_dims); } void CommandEncoder::dispatch_threads( MTL::Size grid_dims, MTL::Size group_dims) { maybeInsertBarrier(); buffer_ops_++; get_command_encoder()->dispatchThreads(grid_dims, group_dims); } void CommandEncoder::barrier() { get_command_encoder()->memoryBarrier(MTL::BarrierScopeBuffers); } void CommandEncoder::end_encoding() { // Each command encoder has a unique fence. We also store a map of // all previous outputs of command encoders to their corresponding fence. // - The command encoder records its inputs and outputs. // - Wait on a fence if any inputs in the encoder are outputs of a previous // encoder. // - Update the map of outputs to include this command encoder's outputs. // - Always signal this command encoders fence. // - Add a completion handler for this command encoder that removes outputs // from the map to limit the growth of the map and avoid unnecessary waits // - Temporaries are a special case as they do not cross command encoder // boundaries. These can be removed early from the encoders inputs and // outputs since they don't need synchronization. if (!encoder_) { return; } // Remove temporaries from inputs and outputs. 
for (auto& t : temporaries_) { all_outputs_.erase(t.buffer().ptr()); all_inputs_.erase(t.buffer().ptr()); } // Keep references to the fences we waited on and put them in the completion // handler so they are not prematurely released. std::unordered_set> waiting_on; { std::lock_guard lk(outputs_mtx_); for (auto& in : all_inputs_) { if (auto it = prev_ce_outputs_.find(in); it != prev_ce_outputs_.end()) { // If we've already waited on a fence, don't wait on it again. if (waiting_on.find(it->second) == waiting_on.end()) { encoder_->waitForFence(it->second.get()); waiting_on.insert(it->second); } } } for (auto& out : all_outputs_) { prev_ce_outputs_[out] = fence_; } } encoder_->updateFence(fence_.get()); buffer_->addCompletedHandler([this, fence = std::move(fence_), temporaries = std::move(temporaries_), all_outputs = std::move(all_outputs_), waiting_on = std::move(waiting_on)]( MTL::CommandBuffer*) mutable { std::lock_guard lk(outputs_mtx_); for (auto& o : all_outputs) { if (auto it = prev_ce_outputs_.find(o); it != prev_ce_outputs_.end()) { if (it->second == fence) { prev_ce_outputs_.erase(it); } } } }); encoder_->endEncoding(); encoder_.reset(); needs_barrier_ = false; concurrent_ = false; prev_outputs_.clear(); next_outputs_.clear(); concurrent_outputs_.clear(); all_inputs_.clear(); } bool CommandEncoder::needs_commit() const { auto [max_ops, max_mb] = device_.get_max_ops_mb_per_buffer(); return (buffer_ops_ > max_ops) || ((buffer_sizes_ >> 20) > max_mb); } void CommandEncoder::commit() { buffer_->commit(); buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences()); buffer_ops_ = 0; buffer_sizes_ = 0; } MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() { if (!encoder_) { encoder_ = NS::RetainPtr( buffer_->computeCommandEncoder(MTL::DispatchTypeConcurrent)); fence_ = NS::TransferPtr(device_.mtl_device()->newFence()); } return encoder_.get(); } Device::Device() { auto pool = new_scoped_memory_pool(); device_ = load_device(); 
default_library_ = load_default_library(device_); arch_ = env::metal_gpu_arch(); if (arch_.empty()) { arch_ = std::string(device_->architecture()->name()->utf8String()); } int ag_tens = 0; int ag_ones = 0; if (arch_.size() >= 3) { ag_tens = arch_[arch_.size() - 3] - '0'; ag_ones = arch_[arch_.size() - 2] - '0'; ag_tens = (ag_tens < 10 && ag_tens >= 0) ? ag_tens : 0; ag_ones = (ag_ones < 10 && ag_ones >= 0) ? ag_ones : 0; } arch_gen_ = ag_tens * 10 + ag_ones; auto arch = arch_.back(); switch (arch) { case 'p': // phone max_ops_per_buffer_ = 20; max_mb_per_buffer_ = 40; break; case 'g': // base, pro max_ops_per_buffer_ = 40; max_mb_per_buffer_ = 40; break; case 's': // max max_ops_per_buffer_ = 50; max_mb_per_buffer_ = 50; break; case 'd': // ultra max_ops_per_buffer_ = 50; max_mb_per_buffer_ = 50; break; default: // default to medium max_ops_per_buffer_ = 40; max_mb_per_buffer_ = 40; break; } max_ops_per_buffer_ = env::max_ops_per_buffer(max_ops_per_buffer_); max_mb_per_buffer_ = env::max_mb_per_buffer(max_mb_per_buffer_); } Device::~Device() { auto pool = new_scoped_memory_pool(); for (auto& [l, kernel_map] : library_kernels_) { l->release(); for (auto& [_, k] : kernel_map) { k->release(); } } encoders_.clear(); device_->release(); } bool Device::command_buffer_needs_commit(int index) { return get_command_encoder(index).needs_commit(); } MTL::CommandBuffer* Device::get_command_buffer(int index) { return get_command_encoder(index).get_command_buffer(); } void Device::commit_command_buffer(int index) { get_command_encoder(index).commit(); } void Device::add_temporary(array arr, int index) { get_command_encoder(index).add_temporary(std::move(arr)); } void Device::add_temporaries(std::vector arrays, int index) { get_command_encoder(index).add_temporaries(std::move(arrays)); } void Device::end_encoding(int index) { get_command_encoder(index).end_encoding(); } CommandEncoder& Device::get_command_encoder(int index) { auto it = encoders_.find(index); if (it == 
encoders_.end()) { it = encoders_.try_emplace(index, *this, index, residency_set_).first; } return it->second; } MTL::Library* Device::get_library( const std::string& name, const std::string& path /* = "" */) { { std::shared_lock rlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { return it->second; } } std::unique_lock wlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { return it->second; } auto new_lib = load_library(device_, name, path.c_str()); library_map_.insert({name, new_lib}); return new_lib; } MTL::Library* Device::build_library_(const std::string& source_string) { auto pool = new_scoped_memory_pool(); auto ns_code = NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding); NS::Error* error = nullptr; auto options = MTL::CompileOptions::alloc()->init(); options->setFastMathEnabled(false); options->setLanguageVersion(get_metal_version()); #ifndef NDEBUG if (options->languageVersion() >= MTL::LanguageVersion3_2) { options->setEnableLogging(true); } #endif auto mtl_lib = device_->newLibrary(ns_code, options, &error); options->release(); // Throw error if unable to compile library if (!mtl_lib) { std::ostringstream msg; msg << "[metal::Device] Unable to build metal library from source\n"; if (error) { msg << error->localizedDescription()->utf8String() << "\n"; } throw std::runtime_error(msg.str()); } return mtl_lib; } MTL::Function* Device::get_function_( const std::string& name, MTL::Library* mtl_lib) { // Pull kernel from library auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding); auto mtl_function = mtl_lib->newFunction(ns_name); return mtl_function; } MTL::Function* Device::get_function_( const std::string& name, const std::string& specialized_name, const MTLFCList& func_consts, MTL::Library* mtl_lib) { if (func_consts.empty() && (specialized_name == name)) { return get_function_(name, mtl_lib); } // Prepare function constants auto mtl_func_consts = 
MTL::FunctionConstantValues::alloc()->init(); for (auto [value, type, index] : func_consts) { mtl_func_consts->setConstantValue(value, type, index); } // Prepare function desc auto desc = MTL::FunctionDescriptor::functionDescriptor(); desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding)); desc->setSpecializedName( NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding)); desc->setConstantValues(mtl_func_consts); // Pull kernel from library NS::Error* error = nullptr; auto mtl_function = mtl_lib->newFunction(desc, &error); // Throw error if unable to build metal function if (!mtl_function) { std::ostringstream msg; msg << "[metal::Device] Unable to load function " << name << "\n"; if (error) { msg << error->localizedDescription()->utf8String() << "\n"; } throw std::runtime_error(msg.str()); } mtl_func_consts->release(); return mtl_function; } MTL::ComputePipelineState* Device::get_kernel_( const std::string& name, const MTL::Function* mtl_function) { // Compile kernel to compute pipeline NS::Error* error = nullptr; MTL::ComputePipelineState* kernel; if (mtl_function) { kernel = device_->newComputePipelineState(mtl_function, &error); } // Throw error if unable to compile metal function if (!mtl_function || !kernel) { std::ostringstream msg; msg << "[metal::Device] Unable to load kernel " << name << "\n"; if (error) { msg << error->localizedDescription()->utf8String() << "\n"; } throw std::runtime_error(msg.str()); } return kernel; } MTL::ComputePipelineState* Device::get_kernel_( const std::string& name, const MTL::Function* mtl_function, const MTL::LinkedFunctions* linked_functions) { // Check inputs if (!linked_functions) { return get_kernel_(name, mtl_function); } if (!mtl_function) { std::ostringstream msg; msg << "[metal::Device] Unable to load kernel " << name << "\n"; throw std::runtime_error(msg.str()); } // Prepare compute pipeline state descriptor auto desc = MTL::ComputePipelineDescriptor::alloc()->init(); 
desc->setComputeFunction(mtl_function); desc->setLinkedFunctions(linked_functions); // Compile kernel to compute pipeline NS::Error* error = nullptr; auto kernel = device_->newComputePipelineState( desc, MTL::PipelineOptionNone, nullptr, &error); // Throw error if unable to compile metal function if (!kernel) { std::ostringstream msg; msg << "[metal::Device] Unable to load kernel " << name << "\n"; if (error) { msg << error->localizedDescription()->utf8String() << "\n"; } throw std::runtime_error(msg.str()); } return kernel; } MTL::Library* Device::get_library_(const std::string& name) { std::shared_lock lock(library_mtx_); auto it = library_map_.find(name); return (it != library_map_.end()) ? it->second : nullptr; } MTL::Library* Device::get_library( const std::string& name, const std::function& builder) { { std::shared_lock rlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { return it->second; } } std::unique_lock wlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { return it->second; } auto mtl_lib = build_library_(builder()); library_map_.insert({name, mtl_lib}); return mtl_lib; } void Device::clear_library(const std::string& name) { std::unique_lock wlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { auto kernel_map_it = library_kernels_.find(it->second); for (auto& [_, kernel] : kernel_map_it->second) { kernel->release(); } library_kernels_.erase(kernel_map_it); it->second->release(); library_map_.erase(it); } } MTL::LinkedFunctions* Device::get_linked_functions_( const std::vector& funcs) { if (funcs.empty()) { return nullptr; } auto lfuncs = MTL::LinkedFunctions::linkedFunctions(); std::vector objs(funcs.size()); for (int i = 0; i < funcs.size(); i++) { objs[i] = funcs[i]; } NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size()); lfuncs->setPrivateFunctions(funcs_arr); return lfuncs; } MTL::ComputePipelineState* Device::get_kernel_( const 
std::string& base_name, MTL::Library* mtl_lib, const std::string& hash_name, const MTLFCList& func_consts /* = {} */, const std::vector& linked_functions /* = {} */) { // Single writer allowed std::unique_lock wlock(kernel_mtx_); // Try loading again to avoid loading twice auto& kernel_map_ = library_kernels_[mtl_lib]; if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) { return it->second; } auto pool = new_scoped_memory_pool(); // Pull kernel from library auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib); // Compile kernel to compute pipeline auto mtl_linked_funcs = get_linked_functions_(linked_functions); auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs); mtl_function->release(); mtl_linked_funcs->release(); // Add kernel to cache kernel_map_.insert({hash_name, kernel}); return kernel; } MTL::ComputePipelineState* Device::get_kernel( const std::string& base_name, MTL::Library* mtl_lib, const std::string& hash_name /* = "" */, const MTLFCList& func_consts /* = {} */, const std::vector& linked_functions /* = {} */) { const auto& kname = hash_name.empty() ? 
base_name : hash_name; { // Multiple readers allowed std::shared_lock lock(kernel_mtx_); // Look for cached kernel auto& kernel_map_ = library_kernels_[mtl_lib]; if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) { return it->second; } } return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions); } MTL::ComputePipelineState* Device::get_kernel( const std::string& base_name, const std::string& hash_name /* = "" */, const MTLFCList& func_consts /* = {} */, const std::vector& linked_functions /* = {} */) { return get_kernel( base_name, default_library_, hash_name, func_consts, linked_functions); } void Device::set_residency_set(const MTL::ResidencySet* residency_set) { if (residency_set_ != nullptr) { throw std::runtime_error( "[Device::set_residency_set] Can only be set once."); } if (residency_set == nullptr) { return; } residency_set_ = residency_set; // Attach residency set to existing command queues for (auto& [_, encoder] : encoders_) { encoder.get_command_queue()->addResidencySet(residency_set_); } } Device& device(mlx::core::Device) { // Leak singleton device intentionally, to avoid cases where a compute kernel // returns and tries to access the object after it has been freed by the main // thread teardown. static Device* metal_device = new Device; return *metal_device; } std::unique_ptr> new_scoped_memory_pool() { auto dtor = [](void* ptr) { static_cast(ptr)->release(); }; return std::unique_ptr>( NS::AutoreleasePool::alloc()->init(), dtor); } } // namespace mlx::core::metal ================================================ FILE: mlx/backend/metal/device.h ================================================ // Copyright © 2023-2024 Apple Inc. 
#pragma once

// NOTE(review): in this copy the bare #include directives below have lost
// their <...> targets, and many template argument lists in this header are
// missing (e.g. `std::vector>`, `NS::SharedPtr queue_`, `template >>`).
// They are left untouched here; restore them from the upstream header.
#include
#include
#include
#include
#include
#include
#include
#include "mlx/array.h"
#include "mlx/device.h"

namespace mlx::core::metal {

// Function-constant specifications passed to Device::get_kernel() and on to
// get_function_() when a specialized kernel is requested.
// NOTE(review): element type lost in this copy; entries appear to be
// (value pointer, MTL::DataType, index) tuples — confirm upstream.
using MTLFCList = std::vector>;

class Device;

// Encodes compute work for one command queue of a Device. Device keeps one
// encoder per index (see Device::get_command_encoder(int index)); callers
// pass a stream index for that argument.
class MLX_API CommandEncoder {
 public:
  CommandEncoder(Device& d, int index, const MTL::ResidencySet* residency_set);
  // Non-copyable.
  CommandEncoder(const CommandEncoder&) = delete;
  CommandEncoder& operator=(const CommandEncoder&) = delete;

  // Scope guard returned by start_concurrent(). Construction sets the
  // encoder's concurrent_ flag. Destruction clears the flag, merges
  // concurrent_outputs_ into prev_outputs_, and then empties
  // concurrent_outputs_.
  struct ConcurrentContext {
    ConcurrentContext(CommandEncoder& enc) : enc(enc) {
      enc.concurrent_ = true;
    }
    ~ConcurrentContext() {
      enc.concurrent_ = false;
      enc.prev_outputs_.insert(
          enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
      enc.concurrent_outputs_.clear();
    }

   private:
    CommandEncoder& enc;
  };

  // Bind a raw buffer / an array's buffer to kernel argument slot `idx`,
  // starting `offset` bytes in.
  void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0);
  void set_input_array(const array& a, int idx, int64_t offset = 0);
  void set_output_array(array& a, int idx, int64_t offset = 0);
  // Record `a` as an output without binding it to a slot. Per its caller in
  // fence.cpp, this keeps dependent kernels from starting before this work
  // is done.
  void register_output_array(const array& a);
  // Keep arrays alive for the lifetime of the encoded work.
  void add_temporary(array arr);
  void add_temporaries(std::vector arrays);
  // Encode a dispatch of the current pipeline. Judging from callers,
  // `grid_dims` counts threadgroups in dispatch_threadgroups() and threads
  // in dispatch_threads().
  void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
  void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
  void maybeInsertBarrier();

  // Select the pipeline used by subsequent dispatches on the active compute
  // encoder.
  void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
    get_command_encoder()->setComputePipelineState(kernel);
  }

  // Copy the first `nelems` elements of `vec`
  // (nelems * sizeof(Vec::value_type) bytes) inline into argument slot `idx`.
  // NOTE(review): the template parameter list / constraint is missing in
  // this copy.
  template >>
  void set_vector_bytes(const Vec& vec, size_t nelems, int idx) {
    get_command_encoder()->setBytes(
        vec.data(), nelems * sizeof(typename Vec::value_type), idx);
  }
  // Same as above, copying every element of `vec`.
  template >>
  void set_vector_bytes(const Vec& vec, int idx) {
    return set_vector_bytes(vec, vec.size(), idx);
  }
  // Copy `n` values of type T starting at `v` into slot `idx`.
  template
  void set_bytes(const T* v, int n, int idx) {
    return get_command_encoder()->setBytes(v, n * sizeof(T), idx);
  }
  // Copy the single value `v` (sizeof(T) bytes) into slot `idx`.
  template
  void set_bytes(const T& v, int idx) {
    return get_command_encoder()->setBytes(&v, sizeof(T), idx);
  }

  // Reserve `length` bytes of threadgroup memory at index `idx`.
  void set_threadgroup_memory_length(size_t length, int idx) {
    get_command_encoder()->setThreadgroupMemoryLength(length, idx);
  }

  // Enter a concurrent section. It lasts as long as the returned
  // ConcurrentContext is alive.
  ConcurrentContext start_concurrent() {
    return ConcurrentContext(*this);
  }

  void barrier();
  void end_encoding();
  bool needs_commit() const;
  void commit();

  // Non-owning accessors for the underlying queue and current buffer.
  MTL::CommandQueue* get_command_queue() const {
    return queue_.get();
  }
  MTL::CommandBuffer* get_command_buffer() const {
    return buffer_.get();
  }

 private:
  MTL::ComputeCommandEncoder* get_command_encoder();

  Device& device_;

  // Buffer that stores encoded commands.
  NS::SharedPtr queue_;
  NS::SharedPtr buffer_;
  // Running totals for the current buffer; presumably compared against
  // Device::get_max_ops_mb_per_buffer() by needs_commit() — confirm.
  int buffer_ops_{0};
  size_t buffer_sizes_{0};

  // Encoder for issuing GPU commands.
  // The members are used within a single ComputeCommandEncoder and will be
  // reset after calling end_encoding().
  NS::SharedPtr encoder_;
  NS::SharedPtr fence_;
  bool needs_barrier_{false};
  // Set only while a ConcurrentContext is alive.
  bool concurrent_{false};
  std::vector temporaries_;
  std::unordered_set prev_outputs_;
  std::unordered_set next_outputs_;
  // Outputs recorded while concurrent_ is set; merged into prev_outputs_
  // when the ConcurrentContext ends.
  std::unordered_set concurrent_outputs_;
  std::unordered_set all_inputs_;
  std::unordered_set all_outputs_;
  // A map of prior command encoder outputs to their corresponding fence.
  std::unordered_map> prev_ce_outputs_;
  std::mutex outputs_mtx_;
};

// Wraps the process-wide MTL::Device: per-index command encoders, compiled
// libraries, and a per-library cache of compute pipelines.
class MLX_API Device {
 public:
  Device();
  // Non-copyable.
  Device(const Device&) = delete;
  Device& operator=(const Device&) = delete;
  ~Device();

  MTL::Device* mtl_device() {
    return device_;
  };

  const std::string& get_architecture() const {
    return arch_;
  }
  int get_architecture_gen() const {
    return arch_gen_;
  }
  // Returns (max_ops_per_buffer_, max_mb_per_buffer_).
  std::tuple get_max_ops_mb_per_buffer() const {
    return std::make_tuple(max_ops_per_buffer_, max_mb_per_buffer_);
  }

  // Per-index command buffer / encoder management. Callers in this backend
  // pass Stream::index.
  MTL::CommandBuffer* get_command_buffer(int index);
  bool command_buffer_needs_commit(int index);
  void commit_command_buffer(int index);
  CommandEncoder& get_command_encoder(int index);
  void end_encoding(int index);

  // Look up (or load) a Metal library by name, optionally from `path`.
  MTL::Library* get_library(
      const std::string& name,
      const std::string& path = "");
  // Overload taking a builder callback (its signature is lost in this copy);
  // presumably it supplies source to compile on a cache miss — confirm.
  MTL::Library* get_library(
      const std::string& name,
      const std::function& builder);
  void clear_library(const std::string& name);

  // Return the compute pipeline for `base_name` in `mtl_lib`. Pipelines are
  // cached per library under `hash_name`, or `base_name` when `hash_name` is
  // empty. A cache hit takes kernel_mtx_ shared. A miss is compiled under an
  // exclusive lock, with a re-check so a kernel is not built twice.
  MTL::ComputePipelineState* get_kernel(
      const std::string& base_name,
      MTL::Library* mtl_lib,
      const std::string& hash_name = "",
      const MTLFCList& func_consts = {},
      const std::vector& linked_functions = {});

  // As above, looking the kernel up in default_library_.
  MTL::ComputePipelineState* get_kernel(
      const std::string& base_name,
      const std::string& hash_name = "",
      const MTLFCList& func_consts = {},
      const std::vector& linked_functions = {});

  // Record temporary arrays for the given stream index
  void add_temporary(array arr, int index);
  void add_temporaries(std::vector arrays, int index);

  // May be set only once; a second call throws std::runtime_error. A null
  // set is ignored. When set, it is attached to the command queues of all
  // existing encoders.
  void set_residency_set(const MTL::ResidencySet* residency_set);

 private:
  MTL::Library* get_library_cache_(const std::string& name);
  MTL::Library* get_library_(const std::string& name);
  MTL::Library* build_library_(const std::string& source_string);
  MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
  // Fetch `name` from `mtl_lib` specialized with `func_consts`;
  // `specialized_name` identifies the specialization.
  MTL::Function* get_function_(
      const std::string& name,
      const std::string& specialized_name,
      const MTLFCList& func_consts,
      MTL::Library* mtl_lib);
  MTL::LinkedFunctions* get_linked_functions_(
      const std::vector& funcs);
  MTL::ComputePipelineState* get_kernel_(
      const std::string& name,
      const MTL::Function* mtl_function);
  MTL::ComputePipelineState* get_kernel_(
      const std::string& name,
      const MTL::Function* mtl_function,
      const MTL::LinkedFunctions* linked_functions);
  // Slow path of get_kernel(): compiles and caches under kernel_mtx_.
  MTL::ComputePipelineState* get_kernel_(
      const std::string& base_name,
      MTL::Library* mtl_lib,
      const std::string& hash_name,
      const MTLFCList& func_consts = {},
      const std::vector& linked_functions = {});

  MTL::Device* device_;
  // Command encoders, one per index (key type lost in this copy).
  std::unordered_map encoders_;
  // Guards library_kernels_ (shared for lookups, exclusive for inserts).
  std::shared_mutex kernel_mtx_;
  // Presumably guards library_map_ — confirm.
  std::shared_mutex library_mtx_;
  std::unordered_map library_map_;
  MTL::Library* default_library_;
  // Per-library cache: library -> (kernel key -> pipeline state).
  std::unordered_map<
      MTL::Library*,
      std::unordered_map>
      library_kernels_;
  const MTL::ResidencySet* residency_set_{nullptr};
  std::string arch_;
  int arch_gen_;
  int max_ops_per_buffer_;
  int max_mb_per_buffer_;
};

// Returns the singleton Metal device. It is intentionally never destroyed,
// so kernels completing during teardown cannot touch a freed object.
MLX_API Device& device(mlx::core::Device);

// Returns an owning handle to a fresh NS::AutoreleasePool; the handle's
// deleter releases the pool when it goes out of scope.
std::unique_ptr> new_scoped_memory_pool();

// Whether NAX kernels may be used. Always false when built with
// MLX_METAL_NO_NAX. Otherwise both conditions must hold:
//   - the OS passes the 26.2 availability check;
//   - the architecture generation is at least 18 when the architecture
//     string ends in 'p', or at least 17 otherwise.
// The result is computed once and cached in a function-local static.
inline bool is_nax_available() {
#ifdef MLX_METAL_NO_NAX
  return false;
#else
  auto _check_nax = []() {
    bool can_use_nax = false;
    if (__builtin_available(
            macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
      can_use_nax = true;
    }
    auto& d = metal::device(mlx::core::Device::gpu);
    // Last character of the architecture string selects the threshold.
    auto arch = d.get_architecture().back();
    auto gen = d.get_architecture_gen();
    can_use_nax &= gen >= (arch == 'p' ? 18 : 17);
    return can_use_nax;
  };
  static bool is_nax_available_ = _check_nax();
  return is_nax_available_;
#endif
}

} // namespace mlx::core::metal ================================================ FILE: mlx/backend/metal/device_info.cpp ================================================ // Copyright © 2026 Apple Inc.
#include #include "mlx/backend/gpu/device_info.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal.h" namespace mlx::core::gpu { bool is_available() { return metal::is_available(); } int device_count() { return 1; } const std::unordered_map>& device_info(int device_index) { auto init_device_info = []() -> std::unordered_map> { auto pool = metal::new_scoped_memory_pool(); auto& device = metal::device(mlx::core::Device::gpu); auto raw_device = device.mtl_device(); auto name = std::string(raw_device->name()->utf8String()); auto arch = device.get_architecture(); size_t memsize = 0; size_t length = sizeof(memsize); sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); size_t rsrc_limit = 0; sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0); if (rsrc_limit == 0) { rsrc_limit = 499000; } return { {"device_name", name}, {"architecture", arch}, {"max_buffer_length", raw_device->maxBufferLength()}, {"max_recommended_working_set_size", raw_device->recommendedMaxWorkingSetSize()}, {"memory_size", memsize}, {"resource_limit", rsrc_limit}}; }; static auto device_info_ = init_device_info(); static std::unordered_map> empty; if (device_index == 0) { return device_info_; } else { return empty; } } } // namespace mlx::core::gpu ================================================ FILE: mlx/backend/metal/distributed.cpp ================================================ // Copyright © 2024 Apple Inc. 
#include #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" #include "mlx/distributed/ops.h" #include "mlx/distributed/primitives.h" #include "mlx/fence.h" #include "mlx/scheduler.h" namespace mlx::core::distributed { void AllReduce::eval_gpu(const std::vector&, std::vector&) { throw std::runtime_error("[AllReduce::eval_gpu] has no GPU implementation."); } void AllGather::eval_gpu(const std::vector&, std::vector&) { throw std::runtime_error("[AllGather::eval_gpu] has no GPU implementation."); } void Send::eval_gpu(const std::vector&, std::vector&) { throw std::runtime_error("[Send::eval_gpu] has no GPU implementation."); } void Recv::eval_gpu(const std::vector&, std::vector&) { throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation."); } void ReduceScatter::eval_gpu(const std::vector&, std::vector&) { throw std::runtime_error( "[ReduceScatter::eval_gpu] has no GPU implementation."); } } // namespace mlx::core::distributed ================================================ FILE: mlx/backend/metal/eval.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include #include "mlx/backend/gpu/eval.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" namespace mlx::core::gpu { void new_stream(Stream stream) { if (stream.device == mlx::core::Device::gpu) { metal::device(stream.device).get_command_encoder(stream.index); } } inline void check_error(MTL::CommandBuffer* cbuf) { if (cbuf->status() == MTL::CommandBufferStatusError) { std::ostringstream msg; msg << "[METAL] Command buffer execution failed: " << cbuf->error()->localizedDescription()->utf8String(); throw std::runtime_error(msg.str()); } } void eval(array& arr) { auto pool = metal::new_scoped_memory_pool(); auto s = arr.primitive().stream(); auto& d = metal::device(s.device); auto command_buffer = d.get_command_buffer(s.index); auto outputs = arr.outputs(); { // If the array is a tracer hold a reference // to its inputs so they don't get donated std::vector inputs; if (arr.is_tracer()) { inputs = arr.inputs(); } debug_set_primitive_buffer_label(command_buffer, arr.primitive()); arr.primitive().eval_gpu(arr.inputs(), outputs); } std::unordered_set> buffers; for (auto& in : arr.inputs()) { buffers.insert(in.data_shared_ptr()); } for (auto& s : arr.siblings()) { buffers.insert(s.data_shared_ptr()); } // Remove the output if it was donated to by an input if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { buffers.erase(it); } if (d.command_buffer_needs_commit(s.index)) { d.end_encoding(s.index); scheduler::notify_new_task(s); command_buffer->addCompletedHandler( [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { scheduler::notify_task_completion(s); check_error(cbuf); }); d.commit_command_buffer(s.index); } else { command_buffer->addCompletedHandler( [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); } } void finalize(Stream s) { auto pool = metal::new_scoped_memory_pool(); auto& d = metal::device(s.device); auto cb = 
d.get_command_buffer(s.index); d.end_encoding(s.index); cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); d.commit_command_buffer(s.index); } void synchronize(Stream s) { auto pool = metal::new_scoped_memory_pool(); auto& d = metal::device(s.device); auto cb = d.get_command_buffer(s.index); cb->retain(); d.end_encoding(s.index); d.commit_command_buffer(s.index); cb->waitUntilCompleted(); check_error(cb); cb->release(); } } // namespace mlx::core::gpu ================================================ FILE: mlx/backend/metal/event.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/event.h" #include "mlx/backend/metal/device.h" #include "mlx/scheduler.h" namespace mlx::core { Event::Event(Stream stream) : stream_(stream) { auto dtor = [](void* ptr) { auto p = metal::new_scoped_memory_pool(); static_cast(ptr)->release(); }; auto p = metal::new_scoped_memory_pool(); event_ = std::shared_ptr( metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor); if (event_ == nullptr) { throw std::runtime_error( "[Event::Event] Failed to create Metal shared event."); } } void Event::wait() { if (!static_cast(event_.get()) ->waitUntilSignaledValue(value(), -1)) { throw std::runtime_error("[Event::wait] Timed out"); } } void Event::wait(Stream stream) { if (stream.device == Device::cpu) { scheduler::enqueue(stream, [*this]() mutable { wait(); }); } else { auto& d = metal::device(stream.device); d.end_encoding(stream.index); auto command_buffer = d.get_command_buffer(stream.index); command_buffer->encodeWait(static_cast(event_.get()), value()); command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {}); } } void Event::signal(Stream stream) { if (stream.device == Device::cpu) { scheduler::enqueue(stream, [*this]() mutable { static_cast(event_.get())->setSignaledValue(value()); }); } else { auto& d = metal::device(stream.device); d.end_encoding(stream.index); auto command_buffer = 
d.get_command_buffer(stream.index); command_buffer->encodeSignalEvent( static_cast(event_.get()), value()); command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {}); } } bool Event::is_signaled() const { return static_cast(event_.get())->signaledValue() >= value(); } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/fence.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/fence.h" #include "mlx/backend/metal/device.h" #include "mlx/scheduler.h" #include "mlx/utils.h" namespace mlx::core { struct FenceImpl { FenceImpl() { auto d = metal::device(Device::gpu).mtl_device(); if (!d->supportsFamily(MTL::GPUFamilyMetal3)) { use_fast = false; } else if (__builtin_available(macOS 15, iOS 18, *)) { use_fast = env::metal_fast_synch(); } if (!use_fast) { auto p = metal::new_scoped_memory_pool(); fence = static_cast(d->newSharedEvent()); } else { auto buf = allocator::malloc(sizeof(uint32_t)).ptr(); fence = static_cast(buf); cpu_value()[0] = 0; } } ~FenceImpl() { if (!use_fast) { // Wraps Metal SharedEvent auto p = metal::new_scoped_memory_pool(); static_cast(fence)->release(); } else { allocator::free(allocator::Buffer{static_cast(fence)}); } } bool use_fast{false}; uint32_t count{0}; void* fence; std::atomic_uint* cpu_value() { return static_cast( static_cast(fence)->contents()); } }; Fence::Fence(Stream) { auto dtor = [](void* ptr) { delete static_cast(ptr); }; fence_ = std::shared_ptr(new FenceImpl{}, dtor); } void Fence::wait(Stream stream, const array& x) { auto& f = *static_cast(fence_.get()); if (stream.device == Device::cpu) { scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { auto& f = *static_cast(fence_.get()); if (!f.use_fast) { if (!static_cast(f.fence)->waitUntilSignaledValue( count, -1)) { throw std::runtime_error("[Fence::wait] Timed out"); } return; } while (f.cpu_value()[0] < count) { } }); return; } auto& d = 
metal::device(stream.device); auto idx = stream.index; if (!f.use_fast) { d.end_encoding(idx); auto command_buffer = d.get_command_buffer(idx); command_buffer->encodeWait(static_cast(f.fence), f.count); command_buffer->addCompletedHandler( [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); return; } auto& compute_encoder = d.get_command_encoder(idx); // Register outputs to ensure that no kernels which depends on the // output starts before this one is done compute_encoder.register_output_array(x); auto kernel = d.get_kernel("fence_wait"); MTL::Size kernel_dims = MTL::Size(1, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); auto buf = static_cast(f.fence); compute_encoder.set_buffer(buf, 0); compute_encoder.set_bytes(f.count, 1); compute_encoder.dispatch_threads(kernel_dims, kernel_dims); d.get_command_buffer(idx)->addCompletedHandler( [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); } void Fence::update(Stream stream, const array& x, bool cross_device) { auto& f = *static_cast(fence_.get()); f.count++; if (stream.device == Device::cpu) { scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { auto& f = *static_cast(fence_.get()); if (!f.use_fast) { static_cast(f.fence)->setSignaledValue(count); return; } f.cpu_value()[0] = count; }); return; } auto& d = metal::device(stream.device); auto idx = stream.index; if (!f.use_fast) { d.end_encoding(idx); auto command_buffer = d.get_command_buffer(idx); command_buffer->encodeSignalEvent( static_cast(f.fence), f.count); command_buffer->addCompletedHandler( [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); return; } // Launch input visibility kernels auto& compute_encoder = d.get_command_encoder(idx); if (cross_device) { auto kernel = d.get_kernel("input_coherent"); uint32_t nthreads = (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t); MTL::Size group_dims = MTL::Size(1024, 1, 1); MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1); 
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x, 0); compute_encoder.set_bytes(nthreads, 1); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } // Barrier on previous kernels compute_encoder.barrier(); // Launch value update kernel auto kernel = d.get_kernel("fence_update"); MTL::Size kernel_dims = MTL::Size(1, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); auto buf = static_cast(f.fence); compute_encoder.set_buffer(buf, 0); compute_encoder.set_bytes(f.count, 1); compute_encoder.dispatch_threads(kernel_dims, kernel_dims); d.get_command_buffer(idx)->addCompletedHandler( [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/fft.cpp ================================================ // Copyright © 2024 Apple Inc. #include #include #include #include #include #include "mlx/3rdparty/pocketfft.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/binary.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/unary.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" namespace mlx::core { using MTLFC = std::tuple; #define MAX_STOCKHAM_FFT_SIZE 4096 #define MAX_RADER_FFT_SIZE 2048 #define MAX_BLUESTEIN_FFT_SIZE 2048 // Threadgroup memory batching improves throughput for small n #define MIN_THREADGROUP_MEM_SIZE 256 // For strided reads/writes, coalesce at least this many complex64s #define MIN_COALESCE_WIDTH 4 inline const std::vector supported_radices() { // Ordered by preference in decomposition. 
return {13, 11, 8, 7, 6, 5, 4, 3, 2}; } std::vector prime_factors(int n) { int z = 2; std::vector factors; while (z * z <= n) { if (n % z == 0) { factors.push_back(z); n /= z; } else { z++; } } if (n > 1) { factors.push_back(n); } return factors; } struct FourStepParams { bool required = false; bool first_step = true; int n1 = 0; int n2 = 0; }; // Forward Declaration void fft_op( const array& in, array& out, size_t axis, bool inverse, bool real, const FourStepParams four_step_params, bool inplace, const Stream& s); struct FFTPlan { int n = 0; // Number of steps for each radix in the Stockham decomposition std::vector stockham; // Number of steps for each radix in the Rader decomposition std::vector rader; // Rader factor, 1 if no rader factors int rader_n = 1; int bluestein_n = -1; // Four step FFT bool four_step = false; int n1 = 0; int n2 = 0; }; int next_fast_n(int n) { return next_power_of_2(n); } std::vector plan_stockham_fft(int n) { auto radices = supported_radices(); std::vector plan(radices.size(), 0); int orig_n = n; if (n == 1) { return plan; } for (int i = 0; i < radices.size(); i++) { int radix = radices[i]; // Manually tuned radices for powers of 2 if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) { continue; } while (n % radix == 0) { plan[i] += 1; n /= radix; if (n == 1) { return plan; } } } throw std::runtime_error("Unplannable"); } FFTPlan plan_fft(int n) { auto radices = supported_radices(); std::set radices_set(radices.begin(), radices.end()); FFTPlan plan; plan.n = n; plan.rader = std::vector(radices.size(), 0); auto factors = prime_factors(n); int remaining_n = n; // Four Step FFT when N is too large for shared mem. if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) { // For power's of two we have a fast, no transpose four step implementation. plan.four_step = true; // Rough heuristic for choosing faster powers of two when we can plan.n2 = n > 65536 ? 
1024 : 64; plan.n1 = n / plan.n2; return plan; } else if (n > MAX_STOCKHAM_FFT_SIZE) { // Otherwise we use a multi-upload Bluestein's plan.four_step = true; plan.bluestein_n = next_fast_n(2 * n - 1); return plan; } for (int factor : factors) { // Make sure the factor is a supported radix if (radices_set.find(factor) == radices_set.end()) { // We only support a single Rader factor currently // TODO(alexbarron) investigate weirdness with large // Rader sizes -- possibly a compiler issue? if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) { plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE; plan.bluestein_n = next_fast_n(2 * n - 1); plan.stockham = plan_stockham_fft(plan.bluestein_n); plan.rader = std::vector(radices.size(), 0); return plan; } // See if we can use Rader's algorithm to Stockham decompose n - 1 auto rader_factors = prime_factors(factor - 1); for (int rf : rader_factors) { // We don't nest Rader's algorithm so if `factor - 1` // isn't Stockham decomposable we give up and do Bluestein's. if (radices_set.find(rf) == radices_set.end()) { plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE; plan.bluestein_n = next_fast_n(2 * n - 1); plan.stockham = plan_stockham_fft(plan.bluestein_n); plan.rader = std::vector(radices.size(), 0); return plan; } } plan.rader = plan_stockham_fft(factor - 1); plan.rader_n = factor; remaining_n /= factor; } } plan.stockham = plan_stockham_fft(remaining_n); return plan; } int compute_elems_per_thread(FFTPlan plan) { // Heuristics for selecting an efficient number // of threads to use for a particular mixed-radix FFT. 
auto n = plan.n; std::vector steps; auto radices = supported_radices(); steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end()); steps.insert(steps.end(), plan.rader.begin(), plan.rader.end()); std::set used_radices; for (int i = 0; i < steps.size(); i++) { int radix = radices[i % radices.size()]; if (steps[i] > 0) { used_radices.insert(radix); } } // Manual tuning for 7/11/13 if (used_radices.find(7) != used_radices.end() && (used_radices.find(11) != used_radices.end() || used_radices.find(13) != used_radices.end())) { return 7; } else if ( used_radices.find(11) != used_radices.end() && used_radices.find(13) != used_radices.end()) { return 11; } // TODO(alexbarron) Some really weird stuff is going on // for certain `elems_per_thread` on large composite n. // Possibly a compiler issue? if (n == 3159) return 13; if (n == 3645) return 5; if (n == 3969) return 7; if (n == 1982) return 5; if (used_radices.size() == 1) { return *(used_radices.begin()); } if (used_radices.size() == 2) { if (used_radices.find(11) != used_radices.end() || used_radices.find(13) != used_radices.end()) { return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2; } std::vector radix_vec(used_radices.begin(), used_radices.end()); return radix_vec[1]; } // In all other cases use the second smallest radix. 
std::vector radix_vec(used_radices.begin(), used_radices.end()); return radix_vec[1]; } // Rader int mod_exp(int x, int y, int n) { int out = 1; while (y) { if (y & 1) { out = out * x % n; } y >>= 1; x = x * x % n; } return out; } int primitive_root(int n) { auto factors = prime_factors(n - 1); for (int r = 2; r < n - 1; r++) { bool found = true; for (int factor : factors) { if (mod_exp(r, (n - 1) / factor, n) == 1) { found = false; break; } } if (found) { return r; } } return -1; } std::tuple compute_raders_constants( int rader_n, const Stream& s) { int proot = primitive_root(rader_n); // Fermat's little theorem int inv = mod_exp(proot, rader_n - 2, rader_n); std::vector g_q(rader_n - 1); std::vector g_minus_q(rader_n - 1); for (int i = 0; i < rader_n - 1; i++) { g_q[i] = mod_exp(proot, i, rader_n); g_minus_q[i] = mod_exp(inv, i, rader_n); } array g_q_arr(g_q.begin(), {rader_n - 1}); array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1}); std::vector> b_q(rader_n - 1); for (int i = 0; i < rader_n - 1; i++) { float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n; b_q[i] = std::exp(std::complex(0, pi_i)); } array b_q_fft({rader_n - 1}, complex64, nullptr, {}); b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes())); auto b_q_fft_ptr = reinterpret_cast*>(b_q_fft.data()); std::ptrdiff_t item_size = b_q_fft.itemsize(); size_t fft_size = rader_n - 1; // This FFT is always small (<4096, batch 1) so save some overhead // and do it on the CPU pocketfft::c2c( /* shape= */ {fft_size}, /* stride_in= */ {item_size}, /* stride_out= */ {item_size}, /* axes= */ {0}, /* forward= */ true, /* data_in= */ b_q.data(), /* data_out= */ b_q_fft_ptr, /* scale= */ 1.0f); return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr); } // Bluestein std::pair compute_bluestein_constants(int n, int bluestein_n) { // We need to calculate the Bluestein twiddle factors // in double precision for the overall numerical stability // of Bluestein's FFT algorithm to be acceptable. 
// // Metal doesn't support float64, so instead we // manually implement the required operations on cpu. // // In numpy: // w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2)) // w_q = np.fft.fft(1/w_k) // return w_k, w_q std::vector> w_k_vec(n); std::vector> w_q_vec(bluestein_n, 0); for (int i = -n + 1; i < n; i++) { double theta = pow(i, 2) * M_PI / (double)n; w_q_vec[i + n - 1] = std::exp(std::complex(0, theta)); if (i >= 0) { w_k_vec[i] = std::exp(std::complex(0, -theta)); } } array w_k({n}, complex64, nullptr, {}); w_k.set_data(allocator::malloc(w_k.nbytes())); std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data()); array w_q({bluestein_n}, complex64, nullptr, {}); w_q.set_data(allocator::malloc(w_q.nbytes())); auto w_q_ptr = reinterpret_cast*>(w_q.data()); std::ptrdiff_t item_size = w_q.itemsize(); size_t fft_size = bluestein_n; pocketfft::c2c( /* shape= */ {fft_size}, /* stride_in= */ {item_size}, /* stride_out= */ {item_size}, /* axes= */ {0}, /* forward= */ true, /* data_in= */ w_q_vec.data(), /* data_out= */ w_q_ptr, /* scale= */ 1.0f); return std::make_tuple(w_k, w_q); } void multi_upload_bluestein_fft( const array& in, array& out, size_t axis, bool inverse, bool real, FFTPlan& plan, std::vector& copies, const Stream& s) { // TODO(alexbarron) Implement fused kernels for mutli upload bluestein's // algorithm int n = inverse ? out.shape(axis) : in.shape(axis); auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); copies.push_back(w_k); copies.push_back(w_q); auto temp_shape = inverse ? out.shape() : in.shape(); array temp(temp_shape, complex64, nullptr, {}); array temp1(temp_shape, complex64, nullptr, {}); if (real && !inverse) { // Convert float32->complex64 copy_gpu(in, temp, CopyType::General, s); copies.push_back(temp); } else if (real && inverse) { int back_offset = n % 2 == 0 ? 
2 : 1; auto slice_shape = in.shape(); slice_shape[axis] -= back_offset; array slice_temp(slice_shape, complex64, nullptr, {}); array conj_temp(in.shape(), complex64, nullptr, {}); copies.push_back(conj_temp); Shape rstarts(in.ndim(), 0); Shape rstrides(in.ndim(), 1); rstarts[axis] = in.shape(axis) - back_offset; rstrides[axis] = -1; unary_op_gpu({in}, conj_temp, "Conjugate", s); slice_gpu(in, slice_temp, rstarts, rstrides, s); concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s); copies.push_back(temp); } else if (inverse) { unary_op_gpu({in}, temp, "Conjugate", s); copies.push_back(temp); } else { temp.copy_shared_buffer(in); } Strides b_strides(in.ndim(), 0); b_strides[axis] = 1; array w_k_broadcast(temp.shape(), complex64, nullptr, {}); w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size()); binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s); std::vector> pads; auto padded_shape = out.shape(); padded_shape[axis] = plan.bluestein_n; array pad_temp(padded_shape, complex64, nullptr, {}); auto zero = array(complex64_t{0.0f, 0.0f}); copies.push_back(zero); pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s); copies.push_back(pad_temp); array pad_temp1(padded_shape, complex64, nullptr, {}); fft_op( pad_temp, pad_temp1, axis, /*inverse=*/false, /*real=*/false, FourStepParams(), /*inplace=*/false, s); copies.push_back(pad_temp1); array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {}); w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size()); binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s); fft_op( pad_temp, pad_temp1, axis, /* inverse= */ true, /* real= */ false, FourStepParams(), /*inplace=*/true, s); int offset = plan.bluestein_n - (2 * n - 1); Shape starts(in.ndim(), 0); Shape strides(in.ndim(), 1); starts[axis] = plan.bluestein_n - offset - n; array temp2(temp_shape, complex64, nullptr, {}); slice_gpu(pad_temp1, temp2, starts, strides, s); binary_op_gpu_inplace({temp2, w_k_broadcast}, 
temp1, "Multiply", s); if (real && !inverse) { Shape rstarts(in.ndim(), 0); Shape rstrides(in.ndim(), 1); slice_gpu(temp1, out, rstarts, strides, s); } else if (real && inverse) { Strides b_strides(in.ndim(), 0); auto inv_n = array({1.0f / n}, {1}, float32); array temp_float(out.shape(), out.dtype(), nullptr, {}); copies.push_back(temp_float); copies.push_back(inv_n); copies.push_back(temp1); copy_gpu(temp1, temp_float, CopyType::General, s); binary_op_gpu({temp_float, inv_n}, out, "Multiply", s); } else if (inverse) { auto inv_n = array({1.0f / n}, {1}, complex64); array temp3(temp_shape, complex64, nullptr, {}); unary_op_gpu({temp1}, temp3, "Conjugate", s); binary_op_gpu({temp3, inv_n}, out, "Multiply", s); copies.push_back(inv_n); copies.push_back(temp1); copies.push_back(temp3); } else { out.copy_shared_buffer(temp1); } } void four_step_fft( const array& in, array& out, size_t axis, bool inverse, bool real, FFTPlan& plan, std::vector& copies, const Stream& s, bool in_place) { if (plan.bluestein_n == -1) { // Fast no transpose implementation for powers of 2. FourStepParams four_step_params = { /* required= */ true, /* first_step= */ true, plan.n1, plan.n2}; auto temp_shape = (real && inverse) ? out.shape() : in.shape(); array temp(temp_shape, complex64, nullptr, {}); fft_op( in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s); four_step_params.first_step = false; fft_op( temp, out, axis, inverse, real, four_step_params, /*inplace=*/in_place, s); copies.push_back(temp); } else { multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s); } } void fft_op( const array& in, array& out, size_t axis, bool inverse, bool real, const FourStepParams four_step_params, bool inplace, const Stream& s) { auto& d = metal::device(s.device); size_t n = out.dtype() == float32 ? 
out.shape(axis) : in.shape(axis); if (n == 1) { out.copy_shared_buffer(in); return; } if (four_step_params.required) { // Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2; } // Make sure that the array is contiguous and has stride 1 in the FFT dim std::vector copies; auto check_input = [&axis, &copies, &s](const array& x) { // TODO: Pass the strides to the kernel so // we can avoid the copy when x is not contiguous. bool no_copy = x.strides()[axis] == 1 && (x.flags().row_contiguous || x.flags().col_contiguous); if (no_copy) { return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); Strides strides; int64_t cur_stride = x.shape(axis); for (int a = 0; a < x.ndim(); a++) { if (a == axis) { strides.push_back(1); } else { strides.push_back(cur_stride); cur_stride *= x.shape(a); } } auto flags = x.flags(); auto [data_size, is_row_contiguous, is_col_contiguous] = check_contiguity(x.shape(), strides); flags.col_contiguous = is_col_contiguous; flags.row_contiguous = is_row_contiguous; flags.contiguous = data_size == x_copy.size(); x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags); copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s); copies.push_back(x_copy); return x_copy; } }; const array& in_contiguous = check_input(in); // real to complex: n -> (n/2)+1 // complex to real: (n/2)+1 -> n auto out_strides = in_contiguous.strides(); size_t out_data_size = in_contiguous.data_size(); if (in.shape(axis) != out.shape(axis)) { for (int i = 0; i < out_strides.size(); i++) { if (out_strides[i] != 1) { out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis); } } out_data_size = out_data_size / in.shape(axis) * out.shape(axis); } auto plan = plan_fft(n); if (plan.four_step) { four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace); d.add_temporaries(std::move(copies), s.index); return; } // TODO: allow donation here if (!inplace) 
{ out.set_data( allocator::malloc(out.nbytes()), out_data_size, out_strides, in_contiguous.flags()); } auto radices = supported_radices(); int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n; // Setup function constants bool power_of_2 = is_power_of_2(fft_size); auto make_int = [](int* a, int i) { return std::make_tuple(a, MTL::DataType::DataTypeInt, i); }; auto make_bool = [](bool* a, int i) { return std::make_tuple(a, MTL::DataType::DataTypeBool, i); }; std::vector func_consts = { make_bool(&inverse, 0), make_bool(&power_of_2, 1)}; // Start of radix/rader step constants int index = 4; for (int i = 0; i < plan.stockham.size(); i++) { func_consts.push_back(make_int(&plan.stockham[i], index)); index += 1; } for (int i = 0; i < plan.rader.size(); i++) { func_consts.push_back(make_int(&plan.rader[i], index)); index += 1; } int elems_per_thread = compute_elems_per_thread(plan); func_consts.push_back(make_int(&elems_per_thread, 2)); int rader_m = n / plan.rader_n; func_consts.push_back(make_int(&rader_m, 3)); // The overall number of FFTs we're going to compute for this input size_t size = out.dtype() == float32 ? out.size() : in.size(); if (real && inverse && four_step_params.required) { size = out.size(); } int total_batch_size = size / n; int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread; // We batch among threadgroups for improved efficiency when n is small int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1); if (four_step_params.required) { // Require a threadgroup batch size of at least 4 for four step FFT // so we can coalesce the memory accesses. 
threadgroup_batch_size = std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH); } int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size); // FFTs up to 2^20 are currently supported assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE); // ceil divide int batch_size = (total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size; if (real && !four_step_params.required) { // We can perform 2 RFFTs at once so the batch size is halved. batch_size = (batch_size + 2 - 1) / 2; } auto& compute_encoder = d.get_command_encoder(s.index); auto in_type_str = in.dtype() == float32 ? "float" : "float2"; auto out_type_str = out.dtype() == float32 ? "float" : "float2"; // Only required by four step int step = -1; { std::ostringstream kname; std::string inv_string = inverse ? "true" : "false"; std::string real_string = real ? "true" : "false"; std::string func_name; if (plan.bluestein_n > 0) { kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_" << out_type_str; func_name = "bluestein_fft"; } else if (plan.rader_n > 1) { kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_" << out_type_str; func_name = "rader_fft"; } else if (four_step_params.required) { step = four_step_params.first_step ? 0 : 1; kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str << "_" << out_type_str << "_" << step << "_" << real_string; func_name = "four_step_fft"; } else { kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_" << out_type_str; func_name = "fft"; } std::string base_name = kname.str(); // We use a specialized kernel for each FFT size kname << "_n" << fft_size << "_inv_" << inverse; std::string hash_name = kname.str(); auto template_def = func_name == "four_step_fft" ? 
get_template_definition( base_name, func_name, threadgroup_mem_size, in_type_str, out_type_str, step, real) : get_template_definition( base_name, func_name, threadgroup_mem_size, in_type_str, out_type_str); auto kernel = get_fft_kernel(d, base_name, hash_name, func_consts, template_def); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in_contiguous, 0); compute_encoder.set_output_array(out, 1); if (plan.bluestein_n > 0) { // Precomputed twiddle factors for Bluestein's auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); copies.push_back(w_q); copies.push_back(w_k); compute_encoder.set_input_array(w_q, 2); // w_q compute_encoder.set_input_array(w_k, 3); // w_k compute_encoder.set_bytes(n, 4); compute_encoder.set_bytes(plan.bluestein_n, 5); compute_encoder.set_bytes(total_batch_size, 6); } else if (plan.rader_n > 1) { auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s); copies.push_back(b_q); copies.push_back(g_q); copies.push_back(g_minus_q); compute_encoder.set_input_array(b_q, 2); compute_encoder.set_input_array(g_q, 3); compute_encoder.set_input_array(g_minus_q, 4); compute_encoder.set_bytes(n, 5); compute_encoder.set_bytes(total_batch_size, 6); compute_encoder.set_bytes(plan.rader_n, 7); } else if (four_step_params.required) { compute_encoder.set_bytes(four_step_params.n1, 2); compute_encoder.set_bytes(four_step_params.n2, 3); compute_encoder.set_bytes(total_batch_size, 4); } else { compute_encoder.set_bytes(n, 2); compute_encoder.set_bytes(total_batch_size, 3); } auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft); auto grid_dims = MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft); compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); } void fft_op( const array& in, array& out, size_t axis, bool inverse, bool real, bool inplace, const Stream& s) { fft_op(in, out, axis, inverse, real, FourStepParams(), 
inplace, s);
}

// Runs an N-dimensional FFT on the GPU as a chain of 1D fft_op calls, one per
// entry of `axes`.
//
// Intermediate results are stored in complex64 scratch arrays: one is always
// created, a second only when there are more than two axes. The first call
// reads `in`, the last call (i == 0) writes `out`, and the calls in between
// alternate ("ping-pong") between the scratch arrays.
void nd_fft_op(
    const array& in,
    array& out,
    const std::vector& axes,
    bool inverse,
    bool real,
    const Stream& s) {
  // Perform ND FFT on GPU as a series of 1D FFTs
  // Scratch arrays take the shape of `in` for inverse transforms and of `out`
  // for forward transforms.
  auto temp_shape = inverse ? in.shape() : out.shape();
  std::vector temp_arrs;
  temp_arrs.emplace_back(temp_shape, complex64, nullptr, std::vector{});
  if (axes.size() > 2) {
    temp_arrs.emplace_back(
        temp_shape, complex64, nullptr, std::vector{});
  }
  for (int i = axes.size() - 1; i >= 0; i--) {
    // How many 1D FFTs have already been issued before this one.
    int reverse_index = axes.size() - i - 1;
    // For 5D and above, we don't want to reallocate our two temporary arrays.
    // When `inplace` is true, fft_op skips allocating storage for its output
    // (it already has a buffer from an earlier step); the final call that
    // writes `out` is never in-place.
    bool inplace = reverse_index >= 3 && i != 0;
    // Opposite order for fft vs ifft: forward walks `axes` from last to
    // first, inverse walks it from first to last.
    int index = inverse ? reverse_index : i;
    size_t axis = axes[index];
    // Mirror np.fft.(i)rfftn and perform a real transform
    // only on the final axis.
    bool step_real = (real && index == axes.size() - 1);
    // First call reads the caller's input; later calls read the scratch array
    // written by the previous step. i % 2 picks which scratch array is read,
    // 1 - i % 2 which one is written, until the last call writes `out`.
    const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[i % 2];
    array& out_arr = i == 0 ? out : temp_arrs[1 - i % 2];
    fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
  }

  // Hand the scratch arrays to the device for stream `s` so they are kept
  // alive beyond this call (presumably until the queued GPU work that uses
  // them completes).
  auto& d = metal::device(s.device);
  d.add_temporaries(std::move(temp_arrs), s.index);
}

// GPU entry point for the FFT primitive: multi-axis transforms go through
// nd_fft_op, single-axis transforms call fft_op directly (never in-place).
void FFT::eval_gpu(const std::vector& inputs, array& out) {
  auto& s = stream();
  auto& in = inputs[0];
  if (axes_.size() > 1) {
    nd_fft_op(in, out, axes_, inverse_, real_, s);
  } else {
    fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
  }
}

} // namespace mlx::core ================================================ FILE: mlx/backend/metal/hadamard.cpp ================================================ // Copyright © 2024 Apple Inc.
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

// Upper bound on the threadgroup width of the strided size-m pass.
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;

// Emits Metal source for `hadamard_radix_m(thread float *x)`, an unrolled
// O(m^2) in-place transform of m floats driven by the text of
// hadamard_matrices()[m].
//
// Each '\n'-terminated line of that text is one output row; every character
// of a row is written as the operator in front of the matching x[col] term.
// Schematically (exact spacing differs), for a 2-row pattern:
//   tmp[0] = + x[0] + x[1];
//   tmp[1] = + x[0] - x[1];
// followed by a copy loop from tmp back into x. For m == 1 the generated
// function body is empty.
std::string gen_hadamard_codelet(int m) {
  auto h_matrices = hadamard_matrices();

  std::ostringstream codelet;
  codelet << "METAL_FUNC void hadamard_radix_m(thread float *x) {"
          << std::endl;

  // A size-1 transform needs no work: close the function right away.
  if (m == 1) {
    codelet << "}" << std::endl;
    return codelet.str();
  }

  auto& matrix = h_matrices[m];
  codelet << " float tmp[" << m << "];" << std::endl;

  // Scanning starts at offset 1, i.e. the first character of the matrix
  // text is skipped (presumably a leading newline in the literal).
  int row_idx = 0;
  size_t row_begin = 1;
  for (auto row_end = matrix.find('\n', row_begin);
       row_end != std::string_view::npos;
       row_end = matrix.find('\n', row_begin)) {
    auto row = matrix.substr(row_begin, row_end - row_begin);
    codelet << " tmp[" << row_idx << "] = ";
    for (size_t col = 0; col < row.length(); ++col) {
      codelet << " " << row[col] << " x[" << col << "]";
    }
    codelet << ";" << std::endl;
    row_begin = row_end + 1;
    ++row_idx;
  }

  codelet << " for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
          << std::endl;
  codelet << "}" << std::endl;
  return codelet.str();
}

// Applies a Hadamard transform of length m * n1 * n2 to row-contiguous `x`,
// writing into `y`, as up to three kernel launches on stream `s`:
//   1. n1 > 1 : power-of-two transform of size n1 with stride n2,
//   2. always : power-of-two transform of size n2 over contiguous runs,
//   3. m > 1  : size-m transform with stride n = n1 * n2.
// `scale` is passed only to the last launch that runs (the n2 pass when
// m == 1, otherwise the m pass); earlier passes receive 1.0.
void hadamard_mn_contiguous(
    const array& x,
    array& y,
    int m,
    int n1,
    int n2,
    float scale,
    metal::Device& d,
    const Stream& s) {
  const int n = n1 * n2;

  // Per-pass vector read widths and maximum radices.
  const int read_width_n1 = (n1 == 2) ? 2 : 4;
  const int read_width_n2 = (n2 == 2) ? 2 : 4;
  const int read_width_m = (n == 2 || m == 28) ? 2 : 4;
  const int max_radix_1 = std::min(n1, 16);
  const int max_radix_2 = std::min(n2, 16);

  // Only the final pass carries the caller's scale.
  float scale_n1 = 1.0;
  float scale_n2 = (m == 1) ? scale : 1.0;
  float scale_m = scale;

  // Pass n2: row-contiguous power-of-two transform.
  MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
  MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);

  // Pass n1: power-of-two transform strided by n2.
  MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
  MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);

  // Pass m: transform strided by n = n1 * n2.
  const int m_width =
      std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP);
  MTL::Size group_dims_m(m_width, 1, 1);
  MTL::Size grid_dims_m(
      m_width, x.size() / m / read_width_m / m_width, 1);

  // One JIT library holds every pass for this (size, dtype) pair.
  std::string kname;
  kname.reserve(32);
  concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
  auto lib = d.get_library(kname, [&]() {
    const auto dtype_str = get_type_string(x.dtype());
    std::string source;
    concatenate(
        source,
        metal::utils(),
        gen_hadamard_codelet(m),
        metal::hadamard(),
        get_template_definition(
            "n2" + kname,
            "hadamard_n",
            dtype_str,
            n2,
            max_radix_2,
            read_width_n2));
    if (n1 > 1) {
      source += get_template_definition(
          "n1" + kname,
          "hadamard_n",
          dtype_str,
          n1,
          max_radix_1,
          read_width_n1,
          n2);
    }
    if (m > 1) {
      source += get_template_definition(
          "m" + kname, "hadamard_m", dtype_str, n, m, read_width_m);
    }
    return source;
  });

  // Every pass writes `y`; only the input and launch geometry differ.
  auto& encoder = d.get_command_encoder(s.index);
  auto run_pass = [&](const std::string& name,
                      const array& src,
                      float pass_scale,
                      MTL::Size grid,
                      MTL::Size group) {
    auto kernel = d.get_kernel(name, lib);
    encoder.set_compute_pipeline_state(kernel);
    encoder.set_input_array(src, 0);
    encoder.set_output_array(y, 1);
    encoder.set_bytes(pass_scale, 2);
    encoder.dispatch_threads(grid, group);
  };

  if (n1 > 1) {
    run_pass("n1" + kname, x, scale_n1, grid_dims_n1, group_dims_n1);
  }
  // The n2 pass reads the n1 result when that pass ran, else the raw input.
  const array& n2_input = n1 > 1 ? y : x;
  run_pass("n2" + kname, n2_input, scale_n2, grid_dims_n2, group_dims_n2);
  if (m > 1) {
    run_pass("m" + kname, y, scale_m, grid_dims_m, group_dims_m);
  }
}

// GPU evaluation of the Hadamard primitive along the last axis.
//
// The last-axis length is factored as n * m, where m comes from
// decompose_hadamard. When n exceeds 8192 it is further split into
// n1 * n2 with n2 the smallest power of two whose square reaches n.
// A non-row-contiguous input is first copied into `out`, and the transform
// then runs in place on `out`.
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  auto [n, m] = decompose_hadamard(in.shape().back());
  int n1 = 1;
  int n2 = n;
  if (n > 8192) {
    n2 = 2;
    while (n2 * n2 < n) {
      n2 *= 2;
    }
    n1 = n / n2;
  }

  if (!in.flags().row_contiguous) {
    copy_gpu(in, out, CopyType::General, s);
    hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
    return;
  }

  // Reuse the input buffer when it can be donated; otherwise allocate.
  if (in.is_donatable()) {
    out.copy_shared_buffer(in);
  } else {
    out.set_data(allocator::malloc(out.nbytes()));
  }
  hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
}

} // namespace mlx::core ================================================ FILE: mlx/backend/metal/indexing.cpp ================================================ // Copyright © 2023-2024 Apple Inc.
#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/scan.h" #include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/indexing.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" #include "mlx/dtype.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { constexpr int METAL_MAX_INDEX_ARRAYS = 20; std::pair make_index_args( const std::string& idx_type, int nidx) { std::ostringstream idx_args; std::ostringstream idx_arr; for (int i = 0; i < nidx; ++i) { idx_args << fmt::format( "const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i); idx_arr << fmt::format("idx{0}", i); if (i < nidx - 1) { idx_args << "\n"; idx_arr << ","; } } return {idx_args.str(), idx_arr.str()}; } template inline std::string make_op(typename T::ReduceType r, const std::string& dt) { switch (r) { case T::None: return "None"; case T::Sum: return fmt::format("Sum<{0}>", dt); case T::Prod: return fmt::format("Prod<{0}>", dt); case T::Max: return fmt::format("Max<{0}>", dt); case T::Min: return fmt::format("Min<{0}>", dt); } } void Gather::eval_gpu(const std::vector& inputs, array& out) { auto& src = inputs[0]; int nidx = inputs.size() - 1; if (nidx > METAL_MAX_INDEX_ARRAYS) { std::ostringstream msg; msg << "[Gather::eval_gpu] Gathering with more than " << METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported."; throw std::runtime_error(msg.str()); } out.set_data(allocator::malloc(out.nbytes())); if (out.size() == 0) { return; } auto& s = stream(); auto& d = metal::device(s.device); size_t slice_size = 1; for (auto s : slice_sizes_) { slice_size *= s; } bool large_index = nidx && inputs[1].size() > INT32_MAX; bool large_src = src.size() > INT32_MAX; bool large_out = out.size() > INT32_MAX; bool large = large_index 
|| large_src || large_out; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; if (src.flags().row_contiguous && nidx == 1 && axes_[0] == 0 && inputs[1].flags().row_contiguous && slice_size == src.strides()[0]) { int work_per_thread = (slice_size > 8 && src.dtype().size() < 4) ? 2 : 1; auto& indices = inputs[1]; std::string kernel_name = fmt::format( "gather_front{0}_{1}_{2}_{3}", type_to_name(out), idx_type_name, large ? "int64_t" : "int", work_per_thread); std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::gather_front(); kernel_source += get_template_definition( kernel_name, "gather_front", get_type_string(out.dtype()), get_type_string(indices.dtype()), large ? "int64_t" : "int", work_per_thread); return kernel_source; }); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread; size_t dim_y = indices.size(); auto group_dims = get_block_dims(dim_x, dim_y, 1); MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1); compute_encoder.set_input_array(src, 0); compute_encoder.set_input_array(indices, 1); compute_encoder.set_output_array(out, 2); compute_encoder.set_bytes(slice_size, 3); compute_encoder.set_bytes(src.shape(0), 4); compute_encoder.dispatch_threads(grid_dims, group_dims); return; } int idx_ndim = nidx ? inputs[1].ndim() : 0; size_t ndim = src.ndim(); std::string kernel_name = fmt::format( "gather{0}{1}_{2}_{3}_{4}", type_to_name(out), idx_type_name, nidx, idx_ndim, large ? "int64_t" : "int"); std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::gather(); std::string out_type_str = get_type_string(out.dtype()); std::string idx_type_str = nidx ? 
get_type_string(inputs[1].dtype()) : "bool"; auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); // Index dimension specializations kernel_source += fmt::format( gather_kernels, type_to_name(out) + idx_type_name, out_type_str, idx_type_str, nidx, idx_args, idx_arr, idx_ndim, large ? "int64_t" : "int"); return kernel_source; }); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); // Launch 3D grid of threads // First two dimensions for the indices, the last one for the slice size_t dim0 = 1; size_t dim1 = 1; if (nidx) { if (inputs[1].ndim() >= 1) { dim0 = inputs[1].shape(0); } if (inputs[1].ndim() >= 2) { dim1 = inputs[1].size() / dim0; } } size_t dim2 = slice_size; auto group_dims = get_block_dims(dim0, dim1, dim2); MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2); // Collect all idx shapes and strides into one place std::vector idx_shapes; std::vector idx_strides; std::vector idx_contigs; for (int i = 0; i < nidx; ++i) { idx_shapes.insert( idx_shapes.end(), inputs[i + 1].shape().begin(), inputs[i + 1].shape().end()); idx_strides.insert( idx_strides.end(), inputs[i + 1].strides().begin(), inputs[i + 1].strides().end()); idx_contigs.push_back(inputs[i + 1].flags().row_contiguous); } // Set all the buffers compute_encoder.set_input_array(src, 0); compute_encoder.set_output_array(out, 1); // Set source info compute_encoder.set_vector_bytes(src.shape(), 2); compute_encoder.set_vector_bytes(src.strides(), 3); compute_encoder.set_bytes(ndim, 4); compute_encoder.set_vector_bytes(slice_sizes_, 5); compute_encoder.set_vector_bytes(axes_, 6); // Set index info // // We don't need to check for empty idx_shapes because gather has a // idx_ndim == 0 specialization compute_encoder.set_vector_bytes(idx_shapes, 7); compute_encoder.set_vector_bytes(idx_strides, 8); compute_encoder.set_vector_bytes(idx_contigs, 9); compute_encoder.set_bytes(idx_ndim, 10); // Set index 
buffers for (int i = 0; i < nidx; ++i) { compute_encoder.set_input_array(inputs[i + 1], 20 + i); } // Launch grid compute_encoder.dispatch_threads(grid_dims, group_dims); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { if (size_of(out.dtype()) == 8) { std::ostringstream msg; msg << "[Scatter::eval_gpu] Does not support " << out.dtype(); throw std::invalid_argument(msg.str()); } int nidx = axes_.size(); if (nidx > METAL_MAX_INDEX_ARRAYS) { std::ostringstream msg; msg << "[Scatter::eval_gpu] Gathering with more than " << METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported."; throw std::runtime_error(msg.str()); } // Copy src into out CopyType copy_type; if (inputs[0].data_size() == 1) { copy_type = CopyType::Scalar; } else if (inputs[0].flags().row_contiguous) { copy_type = CopyType::Vector; } else { copy_type = CopyType::General; } copy_gpu(inputs[0], out, copy_type); auto& upd = inputs.back(); // Empty update if (upd.size() == 0) { return; } // Get stream auto& s = stream(); auto& d = metal::device(s.device); int idx_ndim = nidx ? inputs[1].ndim() : 0; size_t idx_size = nidx ? inputs[1].size() : 1; auto idx_to_out = idx_size / out.size(); int nwork; if (idx_ndim <= 1 || idx_to_out < 1) { nwork = 1; } else if (idx_to_out <= 4) { nwork = 4; } else if (idx_to_out < 16) { nwork = 8; } else if (idx_to_out < 32) { nwork = 16; } else { nwork = 32; } std::string idx_type_name = nidx ? 
type_to_name(inputs[1]) : ""; std::string op_name; switch (reduce_type_) { case Scatter::None: op_name = "none"; break; case Scatter::Sum: op_name = "sum"; break; case Scatter::Prod: op_name = "prod"; break; case Scatter::Max: op_name = "max"; break; case Scatter::Min: op_name = "min"; break; } auto upd_contig = upd.flags().row_contiguous; bool large_out = out.size() > INT32_MAX; bool large_idx = nidx && (inputs[1].size() > INT32_MAX); bool large_upd = upd.size() > INT32_MAX; bool large = large_out || large_idx || large_upd; std::string kernel_name = fmt::format( "scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}", type_to_name(out), idx_type_name, op_name, nidx, upd_contig ? "updc_true" : "updc_false", nwork, large ? "int64_t" : "int"); std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::reduce_utils(), metal::scatter()); std::string out_type_str = get_type_string(out.dtype()); std::string idx_type_str = nidx ? get_type_string(inputs[1].dtype()) : "bool"; std::string op_type = make_op(reduce_type_, out_type_str); auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); kernel_source += fmt::format( scatter_kernels, type_to_name(out) + idx_type_name + "_" + op_name, out_type_str, idx_type_str, op_type, nidx, idx_args, idx_arr, upd_contig, nwork, large ? 
"int64_t" : "int"); return kernel_source; }); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kernel_name, lib); size_t nthreads = upd.size(); compute_encoder.set_compute_pipeline_state(kernel); // Set all the buffers compute_encoder.set_input_array(upd, 1); compute_encoder.set_output_array(out, 2); // Set update info size_t upd_ndim = upd.ndim(); size_t upd_size = 1; for (int i = idx_ndim; i < upd.ndim(); ++i) { upd_size *= upd.shape(i); } // Collect all idx shapes and strides into one place Shape idx_shapes; Strides idx_strides; // To access .data() use char instead of bool // bool is 1 byte in Metal so this is safe std::vector idx_contigs; for (int i = 0; i < nidx; ++i) { idx_shapes.insert( idx_shapes.end(), inputs[i + 1].shape().begin(), inputs[i + 1].shape().end()); idx_strides.insert( idx_strides.end(), inputs[i + 1].strides().begin(), inputs[i + 1].strides().end()); idx_contigs.push_back(inputs[i + 1].flags().row_contiguous); } if (upd_ndim == 0) { // Need placeholders so Metal doesn't complain int shape_ = 0; int64_t stride_ = 0; compute_encoder.set_bytes(shape_, 3); compute_encoder.set_bytes(stride_, 4); } else { compute_encoder.set_vector_bytes(upd.shape(), 3); compute_encoder.set_vector_bytes(upd.strides(), 4); } compute_encoder.set_bytes(upd_ndim, 5); compute_encoder.set_bytes(upd_size, 6); // Set output info size_t out_ndim = out.ndim(); if (out_ndim == 0) { // Need placeholders so Metal doesn't complain int shape_ = 0; int64_t stride_ = 0; compute_encoder.set_bytes(shape_, 7); compute_encoder.set_bytes(stride_, 8); } else { compute_encoder.set_vector_bytes(out.shape(), 7); compute_encoder.set_vector_bytes(out.strides(), 8); } compute_encoder.set_bytes(out_ndim, 9); compute_encoder.set_vector_bytes(axes_, 10); // Set index info if (idx_ndim == 0) { // Add a 0 in idx_shapes and strides to avoid the missing buffer binding // error in the metal API. 
idx_shapes.push_back(0); idx_strides.push_back(0); idx_contigs.push_back(false); } compute_encoder.set_vector_bytes(idx_shapes, 11); compute_encoder.set_vector_bytes(idx_strides, 12); compute_encoder.set_vector_bytes(idx_contigs, 13); compute_encoder.set_bytes(idx_ndim, 14); compute_encoder.set_bytes(idx_size, 15); // Set index buffers for (int i = 0; i < nidx; ++i) { compute_encoder.set_input_array(inputs[i + 1], 20 + i); } // Launch grid auto grid_y = (nthreads / upd_size); grid_y = (grid_y + nwork - 1) / nwork; MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); if (thread_group_size != 1024) { throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads"); } MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { auto& src = inputs[0]; auto& idx = inputs[1]; out.set_data(allocator::malloc(out.nbytes())); if (out.size() == 0) { return; } auto& s = stream(); auto& d = metal::device(s.device); size_t ndim = src.ndim(); bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; std::string kernel_name = fmt::format( "gather_axis{0}{1}_{2}", type_to_name(out), type_to_name(idx), large ? "int64_t" : "int"); std::string lib_name = kernel_name; kernel_name += src.flags().row_contiguous ? "c" : "nc"; kernel_name += idx.flags().row_contiguous ? "c" : "nc"; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::gather_axis(); std::string out_type_str = get_type_string(out.dtype()); std::string idx_type_str = get_type_string(idx.dtype()); for (int i = 0; i < 4; ++i) { bool sc = i & 1; bool ic = i & 2; kernel_source += get_template_definition( lib_name + (sc ? "c" : "nc") + (ic ? "c" : "nc"), "gather_axis", out_type_str, idx_type_str, large ? "int64_t" : "int", sc ? "true" : "false", ic ? 
"true" : "false"); } return kernel_source; }); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); // Grid [size post, index size, size pre] size_t size_pre = 1; size_t size_post = 1; for (int i = 0; i < axis_; ++i) { size_pre *= idx.shape(i); } for (int i = axis_ + 1; i < idx.ndim(); ++i) { size_post *= idx.shape(i); } int idx_ax_size = idx.shape(axis_); auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre); MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre); // Set all the buffers compute_encoder.set_input_array(src, 0); compute_encoder.set_input_array(idx, 1); compute_encoder.set_output_array(out, 2); // Set source info compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4); compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(src.shape(axis_), 8); compute_encoder.set_bytes(src.strides(axis_), 9); compute_encoder.set_bytes(idx.strides(axis_), 10); compute_encoder.dispatch_threads(grid_dims, group_dims); } void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { auto& src = inputs[0]; auto& idx = inputs[1]; auto& upd = inputs[2]; // Copy src into out CopyType copy_type; if (src.data_size() == 1) { copy_type = CopyType::Scalar; } else if (src.flags().row_contiguous) { copy_type = CopyType::Vector; } else { copy_type = CopyType::General; } copy_gpu(src, out, copy_type); // Empty update if (upd.size() == 0) { return; } auto& s = stream(); auto& d = metal::device(s.device); size_t ndim = src.ndim(); bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; std::string op_name; switch (reduce_type_) { case ScatterAxis::None: op_name = "none"; break; case ScatterAxis::Sum: op_name = "sum"; break; } std::string 
kernel_name = fmt::format( "scatter_axis{0}{1}_{2}_{3}", type_to_name(out), type_to_name(idx), op_name, large ? "int64_t" : "int"); std::string lib_name = kernel_name; kernel_name += upd.flags().row_contiguous ? "c" : "nc"; kernel_name += idx.flags().row_contiguous ? "c" : "nc"; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::reduce_utils(); kernel_source += metal::scatter_axis(); std::string out_type_str = get_type_string(out.dtype()); std::string idx_type_str = get_type_string(idx.dtype()); std::string op_type; switch (reduce_type_) { case ScatterAxis::None: op_type = "None"; break; case ScatterAxis::Sum: op_type = "Sum<" + out_type_str + ">"; break; } for (int i = 0; i < 4; ++i) { bool uc = i & 1; bool ic = i & 2; kernel_source += get_template_definition( lib_name + (uc ? "c" : "nc") + (ic ? "c" : "nc"), "scatter_axis", out_type_str, idx_type_str, large ? "int64_t" : "int", op_type, uc ? "true" : "false", ic ? "true" : "false"); } return kernel_source; }); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); // Grid [size post, index size, size pre] size_t size_pre = 1; size_t size_post = 1; for (int i = 0; i < axis_; ++i) { size_pre *= idx.shape(i); } for (int i = axis_ + 1; i < idx.ndim(); ++i) { size_post *= idx.shape(i); } int idx_ax_size = idx.shape(axis_); auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre); MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre); // Set all the buffers compute_encoder.set_input_array(upd, 0); compute_encoder.set_input_array(idx, 1); compute_encoder.set_output_array(out, 2); // Set source info if (ndim > 1) { compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); } else { // The 
following will be ignored in the kernel but we still have to set // some value so that metal validation passes. compute_encoder.set_vector_bytes(idx.shape(), 3); compute_encoder.set_vector_bytes(upd.strides(), 4); compute_encoder.set_vector_bytes(idx.strides(), 5); } compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(out.shape(axis_), 8); compute_encoder.set_bytes(upd.strides(axis_), 9); compute_encoder.set_bytes(idx.strides(axis_), 10); compute_encoder.dispatch_threads(grid_dims, group_dims); } void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { const array& dst = inputs[0]; const array& mask = inputs[1]; const array& src = inputs[2]; auto& s = stream(); auto& d = metal::device(s.device); const size_t total = mask.size(); const CopyType ct = (total == 1) ? CopyType::Scalar : (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General); copy_gpu(dst, out, ct, s); if (total == 0) { return; } array mask_flat = flatten_in_eval(mask, 1, -1, s); if (mask_flat.data() != mask.data()) { d.add_temporary(mask_flat, s.index); } if (!mask_flat.flags().row_contiguous) { mask_flat = contiguous_copy_gpu(mask_flat, s); d.add_temporary(mask_flat, s.index); } // Prefix (exclusive) of mask → scatter_offsets array scatter_offsets(mask_flat.shape(), uint32, nullptr, {}); scatter_offsets.set_data(allocator::malloc(scatter_offsets.nbytes())); d.add_temporary(scatter_offsets, s.index); scan_gpu_inplace( mask_flat, scatter_offsets, Scan::Sum, /*axis=*/1, /*reverse=*/false, /*inclusive=*/false, s); // Kernel selection/build static constexpr std::string_view kBaseName = "masked_assign"; const std::string dtype_tag = type_to_name(out.dtype()); const std::string value_type = get_type_string(out.dtype()); const std::string contiguous = (src.flags().row_contiguous) ? 
"true" : "false"; const std::string kernel_name = fmt::format("{}_{}_{}", kBaseName, dtype_tag, contiguous); auto lib = d.get_library(kernel_name, [&]() { std::string source = metal::utils(); source += metal::masked_scatter(); source += fmt::format(masked_assign_kernel, kernel_name, value_type, contiguous); return source; }); auto kernel = d.get_kernel(kernel_name, lib); // Binding int bind_idx = 0; const int ndim = static_cast(src.ndim()); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(mask_flat, bind_idx++); compute_encoder.set_input_array(scatter_offsets, bind_idx++); compute_encoder.set_input_array(src, bind_idx++); compute_encoder.set_output_array(out, bind_idx++); compute_encoder.set_vector_bytes(src.shape(), bind_idx++); compute_encoder.set_vector_bytes(src.strides(), bind_idx++); compute_encoder.set_bytes(ndim, bind_idx++); compute_encoder.set_bytes(src.size() / src.shape(0), bind_idx++); compute_encoder.set_bytes(mask_flat.size() / mask.shape(0), bind_idx++); // Dispatch auto group_dims = get_block_dims(total, 1, 1); MTL::Size grid_dims(total, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (out.size() == 0) { out.set_data(allocator::malloc(0)); return; } auto& in = inputs[0]; auto& upd = inputs[1]; if (upd.size() == 0) { out.copy_shared_buffer(in); return; } auto ctype = in.flags().contiguous && in.size() == in.data_size() ? CopyType::Vector : CopyType::General; copy_gpu(in, out, in.data_size() == 1 ? 
CopyType::Scalar : ctype, stream()); auto [data_offset, out_strides] = prepare_slice(out, start_indices_, strides_); // Do copy if (reduce_type_ == SliceUpdate::None) { copy_gpu_inplace( /* const array& src = */ upd, /* array& dst = */ out, /* const Shape& data_shape = */ upd.shape(), /* const Strides& i_strides = */ upd.strides(), /* const Strides& o_strides = */ out_strides, /* int64_t i_offset = */ 0, /* int64_t o_offset = */ data_offset, /* CopyType ctype = */ CopyType::GeneralGeneral, /* const Stream& s = */ stream()); return; } std::string op_name; switch (reduce_type_) { case SliceUpdate::None: op_name = "none"; break; case SliceUpdate::Sum: op_name = "sum"; break; case SliceUpdate::Prod: op_name = "prod"; break; case SliceUpdate::Max: op_name = "max"; break; case SliceUpdate::Min: op_name = "min"; break; } bool upd_contiguous = upd.flags().row_contiguous; bool upd_scalar = upd.data_size() == 1; Shape shape; std::vector strides; if (upd_scalar) { std::tie(shape, strides) = collapse_contiguous_dims(upd.shape(), {out_strides, out_strides}); } else { std::tie(shape, strides) = collapse_contiguous_dims(upd.shape(), {upd.strides(), out_strides}); } int ndim_constant = shape.size(); if (ndim_constant > 3) { ndim_constant = 0; } int nwork = 1; if (shape.back() % 4 == 0) { nwork = 4; } else if (shape.back() % 2 == 0) { nwork = 2; } auto [ds, rc, cc] = check_contiguity(shape, strides[1]); bool out_contiguous = rc; bool large = upd.size() > INT32_MAX; std::string kernel_name = fmt::format( "slice_update_{0}_{1}{2}_{3}_{4}_{5}_nw{6}_nd{7}", op_name, type_to_name(out), large ? "int64_t" : "int", out_contiguous ? "oc_true" : "oc_false", upd_contiguous ? "updc_true" : "updc_false", upd_scalar ? 
"upds_true" : "upds_false", nwork, ndim_constant); auto& s = stream(); auto& d = metal::device(s.device); auto lib = d.get_library(kernel_name, [&]() { std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::reduce_utils(), metal::scatter()); std::string out_type = get_type_string(out.dtype()); std::string op_type = make_op(reduce_type_, out_type); kernel_source += fmt::format( slice_update_op_kernel, kernel_name, out_type, large ? "int64_t" : "int", op_type, out_contiguous, upd_contiguous, upd_scalar, nwork, ndim_constant); return kernel_source; }); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); // Set all the buffers int ndim = shape.size(); int64_t size = upd.size(); compute_encoder.set_input_array(upd, 0); compute_encoder.set_output_array(out, 1); compute_encoder.set_vector_bytes(shape, 2); compute_encoder.set_vector_bytes(strides[0], 3); compute_encoder.set_bytes(ndim, 4); compute_encoder.set_bytes(size, 5); compute_encoder.set_vector_bytes(strides[1], 6); compute_encoder.set_bytes(data_offset, 7); // Launch grid int64_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; int64_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; int64_t rest = size / (dim0 * dim1); dim0 /= nwork; auto group_dims = get_block_dims(dim0, dim1, rest); MTL::Size grid_dims(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/jit/includes.h ================================================ // Copyright © 2023-2024 Apple Inc. 
#pragma once namespace mlx::core::metal { const char* utils(); const char* binary_ops(); const char* unary_ops(); const char* ternary_ops(); const char* reduce_utils(); const char* gather(); const char* scatter(); const char* masked_scatter(); const char* arange(); const char* unary(); const char* binary(); const char* binary_two(); const char* copy(); const char* fft(); const char* gather_axis(); const char* gather_front(); const char* hadamard(); const char* logsumexp(); const char* quantized_utils(); const char* quantized(); const char* fp_quantized(); const char* ternary(); const char* scan(); const char* scatter_axis(); const char* softmax(); const char* sort(); const char* reduce(); const char* gemm(); const char* steel_gemm_fused(); const char* steel_gemm_masked(); const char* steel_gemm_splitk(); const char* steel_gemm_gather(); const char* steel_gemm_segmented(); const char* conv(); const char* steel_conv(); const char* steel_conv_3d(); const char* steel_conv_general(); const char* gemv_masked(); const char* steel_attention(); const char* gemm_nax(); const char* steel_gemm_fused_nax(); const char* steel_gemm_gather_nax(); const char* steel_gemm_splitk_nax(); const char* quantized_nax(); const char* fp_quantized_nax(); const char* steel_attention_nax(); } // namespace mlx::core::metal ================================================ FILE: mlx/backend/metal/jit/indexing.h ================================================ // Copyright © 2023-2024 Apple Inc. 
constexpr std::string_view gather_kernels = R"( [[kernel]] void gather{0}_{3}_{6}_{7}( const device {1}* src [[buffer(0)]], device {1}* out [[buffer(1)]], const constant int* src_shape [[buffer(2)]], const constant int64_t* src_strides [[buffer(3)]], const constant size_t& src_ndim [[buffer(4)]], const constant int* slice_sizes [[buffer(5)]], const constant int* axes [[buffer(6)]], const constant int* idx_shapes [[buffer(7)]], const constant int64_t* idx_strides [[buffer(8)]], const constant bool* idx_contigs [[buffer(9)]], const constant int& idx_ndim [[buffer(10)]], {4} uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) {{ Indices<{2}, {3}> idxs{{ {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; return gather_impl<{1}, {2}, {3}, {6}, {7}>( src, out, src_shape, src_strides, src_ndim, slice_sizes, axes, idxs, index, grid_dim); }} )"; constexpr std::string_view scatter_kernels = R"( [[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}( const device {1}* updates [[buffer(1)]], device mlx_atomic<{1}>* out [[buffer(2)]], const constant int* upd_shape [[buffer(3)]], const constant int64_t* upd_strides [[buffer(4)]], const constant size_t& upd_ndim [[buffer(5)]], const constant size_t& upd_size [[buffer(6)]], const constant int* out_shape [[buffer(7)]], const constant int64_t* out_strides [[buffer(8)]], const constant size_t& out_ndim [[buffer(9)]], const constant int* axes [[buffer(10)]], const constant int* idx_shapes [[buffer(11)]], const constant int64_t* idx_strides [[buffer(12)]], const constant bool* idx_contigs [[buffer(13)]], const constant int& idx_ndim [[buffer(14)]], const constant size_t& idx_size [[buffer(15)]], {5} uint2 gid [[thread_position_in_grid]]) {{ Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>( updates, out, upd_shape, upd_strides, upd_ndim, upd_size, out_shape, out_strides, out_ndim, axes, idx_size, idxs, gid); }} 
)"; constexpr std::string_view masked_assign_kernel = R"( template [[host_name("{0}")]] [[kernel]] decltype(masked_assign_impl<{1}, {2}>) masked_assign_impl<{1}, {2}>; )"; constexpr std::string_view slice_update_op_kernel = R"( template [[host_name("{0}")]] [[kernel]] decltype(slice_update_op_impl<{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}>) slice_update_op_impl<{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}>; )"; ================================================ FILE: mlx/backend/metal/jit_kernels.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" using namespace fmt::literals; namespace mlx::core { MTL::ComputePipelineState* get_arange_kernel( metal::Device& d, const std::string& kernel_name, const array& out) { auto lib = d.get_library(kernel_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::arange(); kernel_source += get_template_definition( kernel_name, "arange", get_type_string(out.dtype())); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_unary_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto in_t = get_type_string(in_type); auto out_t = get_type_string(out_type); std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::unary_ops(), metal::unary()); kernel_source += get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1); if (get_work_per_thread(in_type) > 1) { kernel_source += get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op); } kernel_source += get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); kernel_source += get_template_definition( "gn1_" + 
lib_name, "unary_g", in_t, out_t, op, 1, "int"); kernel_source += get_template_definition( "gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4); return kernel_source; }); return d.get_kernel(kernel_name, lib); } void append_binary_kernels( const std::string& lib_name, Dtype in_type, Dtype out_type, const char* op, std::string& kernel_source) { const std::array, 7> kernel_types = {{ {"ss", "binary_ss"}, {"vs2", "binary_vs2"}, {"sv2", "binary_sv2"}, {"vv2", "binary_vv2"}, {"g1large", "binary_g_nd1"}, {"g2large", "binary_g_nd2"}, {"g3large", "binary_g_nd3"}, }}; auto in_t = get_type_string(in_type); auto out_t = get_type_string(out_type); for (auto& [name, func] : kernel_types) { kernel_source += get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); } kernel_source += get_template_definition( "vs_" + lib_name, "binary_vs", in_t, out_t, op, 1); kernel_source += get_template_definition( "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1); kernel_source += get_template_definition( "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1); if (get_work_per_thread(in_type) > 1) { kernel_source += get_template_definition( "vsn_" + lib_name, "binary_vs", in_t, out_t, op); kernel_source += get_template_definition( "svn_" + lib_name, "binary_sv", in_t, out_t, op); kernel_source += get_template_definition( "vvn_" + lib_name, "binary_vv", in_t, out_t, op); } kernel_source += get_template_definition( "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int"); kernel_source += get_template_definition( "g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int"); kernel_source += get_template_definition( "g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int"); kernel_source += get_template_definition( "gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int"); kernel_source += get_template_definition( "gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4); } MTL::ComputePipelineState* get_binary_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype 
out_type, const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; kernel_source = metal::utils(); concatenate(kernel_source, metal::binary_ops(), metal::binary()); append_binary_kernels(lib_name, in_type, out_type, op, kernel_source); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_binary_two_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::binary_ops(), metal::binary_two()); append_binary_kernels(lib_name, in_type, out_type, op, kernel_source); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype type, const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto t_str = get_type_string(type); std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); const std::array, 3> kernel_types = {{ {"g1large", "ternary_g_nd1"}, {"g2large", "ternary_g_nd2"}, {"g3large", "ternary_g_nd3"}, }}; for (auto& [name, func] : kernel_types) { kernel_source += get_template_definition(name + "_" + lib_name, func, t_str, op); } kernel_source += get_template_definition( "v2_" + lib_name, "ternary_v2", t_str, op, false, false); kernel_source += get_template_definition( "sv2_" + lib_name, "ternary_v2", t_str, op, true, false); kernel_source += get_template_definition( "vs2_" + lib_name, "ternary_v2", t_str, op, false, true); if (get_work_per_thread(type) > 1) { kernel_source += get_template_definition( "vn_" + lib_name, "ternary_v", t_str, op, false, 
false); kernel_source += get_template_definition( "svn_" + lib_name, "ternary_v", t_str, op, true, false); kernel_source += get_template_definition( "vsn_" + lib_name, "ternary_v", t_str, op, false, true); } kernel_source += get_template_definition( "v_" + lib_name, "ternary_v", t_str, op, false, false, 1); kernel_source += get_template_definition( "sv_" + lib_name, "ternary_v", t_str, op, true, false, 1); kernel_source += get_template_definition( "vs_" + lib_name, "ternary_v", t_str, op, false, true, 1); kernel_source += get_template_definition( "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int"); kernel_source += get_template_definition( "g2_" + lib_name, "ternary_g_nd2", t_str, op, "int"); kernel_source += get_template_definition( "g3_" + lib_name, "ternary_g_nd3", t_str, op, "int"); kernel_source += get_template_definition( "gn2_" + lib_name, "ternary_g", t_str, op, 2, "int"); kernel_source += get_template_definition( "gn4large_" + lib_name, "ternary_g", t_str, op, 4); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_copy_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::copy(); auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); kernel_source += get_template_definition( "s_" + lib_name, "copy_s", in_type, out_type, 1); kernel_source += get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type); kernel_source += get_template_definition( "v_" + lib_name, "copy_v", in_type, out_type, 1); kernel_source += get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type); if (get_work_per_thread(out.dtype()) > 1) { kernel_source += get_template_definition( "sn_" + lib_name, "copy_s", in_type, out_type); kernel_source += 
get_template_definition( "vn_" + lib_name, "copy_v", in_type, out_type); } kernel_source += get_template_definition( "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( "g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int"); kernel_source += get_template_definition( "g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int"); kernel_source += get_template_definition( "gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int"); kernel_source += get_template_definition( "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int"); kernel_source += get_template_definition( "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int"); kernel_source += get_template_definition( "ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int"); kernel_source += get_template_definition( "g1large_" + lib_name, "copy_g_nd1", in_type, out_type); kernel_source += get_template_definition( "g2large_" + lib_name, "copy_g_nd2", in_type, out_type); kernel_source += get_template_definition( "g3large_" + lib_name, "copy_g_nd3", in_type, out_type); kernel_source += get_template_definition( "gn4large_" + lib_name, "copy_g", in_type, out_type, 4); kernel_source += get_template_definition( "gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type); kernel_source += get_template_definition( "gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type); kernel_source += get_template_definition( "gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type); kernel_source += get_template_definition( "ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_dynamic_copy_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = 
d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::copy(); auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); kernel_source += get_template_definition( "gg1_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( "gg2_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type, "int"); kernel_source += get_template_definition( "gg3_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type, "int"); kernel_source += get_template_definition( "ggn2_" + lib_name, "copy_gg_dynamic", in_type, out_type, 2, "int"); kernel_source += get_template_definition( "gg1large_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type); kernel_source += get_template_definition( "gg2large_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type); kernel_source += get_template_definition( "gg3large_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type); kernel_source += get_template_definition( "ggn4large_" + lib_name, "copy_gg_dynamic", in_type, out_type, 4); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_softmax_kernel( metal::Device& d, const std::string& kernel_name, bool precise, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&] { std::string kernel_source = metal::utils(); auto in_type = get_type_string(out.dtype()); auto acc_type = get_type_string(precise ? 
float32 : out.dtype()); kernel_source += metal::softmax(); kernel_source += get_template_definition( "block_" + lib_name, "softmax_single_row", in_type, acc_type); kernel_source += get_template_definition( "looped_" + lib_name, "softmax_looped", in_type, acc_type); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_logsumexp_kernel( metal::Device& d, const std::string& kernel_name, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&] { auto t_str = get_type_string(out.dtype()); std::string kernel_source; kernel_source = metal::utils(); kernel_source += metal::logsumexp(); kernel_source += get_template_definition("block_" + lib_name, "logsumexp", t_str); kernel_source += get_template_definition( "looped_" + lib_name, "logsumexp_looped", t_str); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_scan_kernel( metal::Device& d, const std::string& kernel_name, bool reverse, bool inclusive, const std::string& reduce_type, const array& in, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto out_type = get_type_string(out.dtype()); std::string op = "Cum" + reduce_type + "<" + out_type + ">"; op[3] = toupper(op[3]); std::ostringstream kernel_source; kernel_source << metal::utils() << metal::scan(); const std::array, 2> scan_kernels = {{ {"contig_", "contiguous_scan"}, {"strided_", "strided_scan"}, }}; for (auto& [prefix, kernel] : scan_kernels) { kernel_source << get_template_definition( prefix + lib_name, kernel, get_type_string(in.dtype()), get_type_string(out.dtype()), op, in.itemsize() <= 4 ? 
4 : 2, inclusive, reverse); } return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_sort_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& out, int bn, int tn) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); kernel_source << metal::utils() << metal::sort(); for (bool is_argsort : {true, false}) { std::string bool_string = is_argsort ? "true" : "false"; std::string func_string = is_argsort ? "carg_" : "c_"; kernel_source << get_template_definition( func_string + lib_name, "block_sort", in_type, out_type, bool_string, bn, tn); kernel_source << get_template_definition( "n" + func_string + lib_name, "block_sort_nc", in_type, out_type, bool_string, bn, tn); } return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_mb_sort_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& idx, int bn, int tn) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::sort(); std::array, 3> kernel_types = { {{"sort_", "mb_block_sort"}, {"partition_", "mb_block_partition"}, {"merge_", "mb_block_merge"}}}; for (auto& [name, func] : kernel_types) { kernel_source << get_template_definition( name + lib_name, func, get_type_string(in.dtype()), get_type_string(idx.dtype()), "true", bn, tn); } return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_reduce_init_kernel( metal::Device& d, const std::string& kernel_name, const std::string& func_name, const std::string& op_name, const Dtype& out_type) { auto lib = d.get_library(kernel_name, [&]() { std::string 
op_type = op_name; op_type[0] = std::toupper(op_name[0]); auto out_t = get_type_string(out_type); std::string op = op_type + "<" + out_t + ">"; std::string kernel_source = metal::utils(); kernel_source += metal::reduce_utils(); kernel_source += metal::reduce(); kernel_source += get_template_definition(kernel_name, func_name, out_t, op); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_reduce_kernel( metal::Device& d, const std::string& kernel_name, const std::string& func_name, const std::string& op_name, const Dtype& in_type, const Dtype& out_type, const std::string& idx_t, int ndim /* = -1 */, int bm /* = -1 */, int bn /* = -1 */) { auto lib = d.get_library(kernel_name, [&]() { std::string op_type = op_name; op_type[0] = std::toupper(op_name[0]); auto in_t = get_type_string(in_type); auto out_t = get_type_string(out_type); std::string op = op_type + "<" + out_t + ">"; std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::reduce_utils(), metal::reduce()); if (bm >= 0) { kernel_source += get_template_definition( kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn); } else if (ndim >= 0) { kernel_source += get_template_definition( kernel_name, func_name, in_t, out_t, op, idx_t, ndim); } else { kernel_source += get_template_definition( kernel_name, func_name, in_t, out_t, op, idx_t); } return kernel_source; }); auto st = d.get_kernel(kernel_name, lib); return st; } MTL::ComputePipelineState* get_steel_gemm_fused_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_fused() << get_template_definition( lib_name, "gemm", 
get_type_string(out.dtype()), bm, bn, bk, wm, wn, transpose_a, transpose_b); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_gemm_splitk_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_splitk() << get_template_definition( lib_name, "gemm_splitk", get_type_string(in.dtype()), get_type_string(out.dtype()), bm, bn, bk, wm, wn, transpose_a, transpose_b, mn_aligned, k_aligned); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel( metal::Device& d, const std::string& kernel_name, const array& in, const array& out, bool axbpy) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_splitk() << get_template_definition( lib_name, axbpy ? "gemm_splitk_accum_axpby" : "gemm_splitk_accum", get_type_string(in.dtype()), get_type_string(out.dtype())); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_steel_gemm_masked_kernel( metal::Device& d, const std::string& kernel_name, const array& out, const std::optional& mask_out, const std::optional& mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; auto out_mask_type = mask_out.has_value() ? 
get_type_string((*mask_out).dtype()) : "nomask_t"; auto op_mask_type = mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_masked() << get_template_definition( lib_name, "block_masked_gemm", get_type_string(out.dtype()), out_mask_type, op_mask_type, bm, bn, bk, wm, wn, transpose_a, transpose_b, mn_aligned, k_aligned); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_steel_gemm_gather_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool rhs) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::gemm(), metal::steel_gemm_gather(), get_template_definition( lib_name, rhs ? "gather_mm_rhs" : "gather_mm", get_type_string(out.dtype()), bm, bn, bk, wm, wn, transpose_a, transpose_b)); return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::gemm(), metal::steel_gemm_segmented(), get_template_definition( lib_name, "segmented_mm", get_type_string(out.dtype()), bm, bn, bk, wm, wn, transpose_a, transpose_b)); return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, const array& 
out, const std::optional& mask_out, const std::optional& mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; auto out_mask_type = mask_out.has_value() ? get_type_string((*mask_out).dtype()) : "nomask_t"; auto op_mask_type = mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; kernel_source << metal::utils() << metal::gemv_masked() << get_template_definition( lib_name, (transpose_mat) ? "gemv_t_masked" : "gemv_masked", get_type_string(out.dtype()), out_mask_type, op_mask_type, bm, bn, sm, sn, tm, tn, contiguous ? 0 : 1); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, const array& out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::conv() << metal::steel_conv() << get_template_definition( lib_name, "implicit_gemm_conv_2d", get_type_string(out.dtype()), bm, bn, bk, wm, wn, n_channel_specialization, small_filter); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_steel_conv_3d_kernel( metal::Device& d, const std::string& kernel_name, const array& out, int bm, int bn, int bk, int wm, int wn, bool small_filter) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::conv() << metal::steel_conv_3d() << get_template_definition( lib_name, "implicit_gemm_conv_3d", get_type_string(out.dtype()), bm, bn, bk, wm, wn, small_filter); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* 
get_steel_conv_general_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& out, int bm, int bn, int bk, int wm, int wn) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::conv() << metal::steel_conv_general() << get_template_definition( lib_name, "implicit_gemm_conv_2d_general", get_type_string(out.dtype()), bm, bn, bk, wm, wn); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_fft_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const std::string& template_def) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; std::string kernel_string; kernel_source << metal::fft() << template_def; return kernel_source.str(); }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_quantized_kernel( metal::Device& d, const std::string& kernel_name, const std::string& template_def, const std::string& mode) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::gemm(), metal::quantized_utils(), (mode == "affine") ? 
metal::quantized() : metal::fp_quantized(), template_def); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_gather_qmm_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& x, int group_size, int bits, const std::string& mode, int bm, int bn, int bk, int wm, int wn, bool transpose) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm()); bool is_affine = mode == "affine"; concatenate( kernel_source, is_affine ? metal::quantized() : metal::fp_quantized(), get_template_definition( lib_name, (is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"), get_type_string(x.dtype()), group_size, bits, bm, bn, bk, wm, wn, transpose)); return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_gemm_fused_nax_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm_nax() << metal::steel_gemm_fused_nax() << get_template_definition( lib_name, "gemm", get_type_string(out.dtype()), bm, bn, bk, wm, wn, transpose_a, transpose_b); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_gemm_gather_nax_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool rhs) { const auto& lib_name = 
kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::gemm_nax(), metal::steel_gemm_gather_nax(), get_template_definition( lib_name, rhs ? "gather_mm_rhs_nax" : "gather_mm_nax", get_type_string(out.dtype()), bm, bn, bk, wm, wn, transpose_a, transpose_b)); return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm_nax() << metal::steel_gemm_splitk_nax() << get_template_definition( lib_name, "gemm_splitk_nax", get_type_string(out.dtype()), bm, bn, bk, wm, wn, transpose_a, transpose_b); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_qmm_nax_kernel( metal::Device& d, const std::string& kernel_name, const std::string& template_def, const std::string& mode) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::gemm_nax(), metal::quantized_utils(), (mode == "affine") ? 
metal::quantized_nax() : metal::fp_quantized_nax(), template_def); return kernel_source; }); return d.get_kernel(kernel_name, lib); } MTL::ComputePipelineState* get_gather_qmm_nax_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& x, int group_size, int bits, const std::string& mode, int bm, int bn, int bk, int wm, int wn, bool transpose) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::gemm_nax(), metal::quantized_utils()); bool is_affine = mode == "affine"; concatenate( kernel_source, is_affine ? metal::quantized_nax() : metal::fp_quantized_nax(), get_template_definition( lib_name, (is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs_nax"), get_type_string(x.dtype()), group_size, bits, bm, bn, bk, wm, wn, transpose)); return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_attention_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& q, int bq, int bk, int bd, int wm, int wn, const array& m) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::steel_attention(), get_template_definition( lib_name, "attention", get_type_string(q.dtype()), bq, bk, bd, wm, wn, get_type_string(m.dtype()))); return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_attention_nax_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, const metal::MTLFCList& func_consts, const array& q, int bq, int bk, int bd, int wm, int wn, const array& m) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { 
std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::steel_attention_nax(), get_template_definition( lib_name, "attention_nax", get_type_string(q.dtype()), bq, bk, bd, wm, wn, get_type_string(m.dtype()))); return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/kernels/CMakeLists.txt ================================================ set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h erf.h expm1f.h fp8.h logging.h utils.h) function(build_kernel_base TARGET SRCFILE DEPS) set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions -Wno-c++20-extensions) if(MLX_METAL_DEBUG) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) endif() if(CMAKE_BUILD_TYPE STREQUAL "Debug" AND MLX_METAL_VERSION GREATER_EQUAL 320) set(METAL_FLAGS ${METAL_FLAGS} -fmetal-enable-logging) endif() if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "") set(METAL_FLAGS ${METAL_FLAGS} "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") endif() add_custom_command( COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE} -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS} OUTPUT ${TARGET}.air COMMENT "Building ${TARGET}.air" VERBATIM) endfunction(build_kernel_base) function(build_kernel KERNEL) set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal) cmake_path(GET KERNEL STEM TARGET) build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}") set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE) endfunction(build_kernel) build_kernel(arg_reduce) build_kernel(conv steel/conv/params.h) build_kernel(gemv steel/utils.h) build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm) build_kernel(rope) build_kernel(scaled_dot_product_attention sdpa_vector.h) if(MLX_METAL_VERSION GREATER_EQUAL 320) build_kernel(fence) endif() set(STEEL_HEADERS steel/defines.h steel/utils.h 
steel/conv/conv.h steel/conv/loader.h steel/conv/loaders/loader_channel_l.h steel/conv/loaders/loader_channel_n.h steel/conv/loaders/loader_general.h steel/conv/kernels/steel_conv.h steel/conv/kernels/steel_conv_3d.h steel/conv/kernels/steel_conv_general.h steel/gemm/gemm.h steel/gemm/mma.h steel/gemm/loader.h steel/gemm/params.h steel/gemm/transforms.h steel/gemm/kernels/steel_gemm_fused.h steel/gemm/kernels/steel_gemm_gather.h steel/gemm/kernels/steel_gemm_masked.h steel/gemm/kernels/steel_gemm_segmented.h steel/gemm/kernels/steel_gemm_splitk.h steel/utils/type_traits.h steel/utils/integral_constant.h) set(STEEL_ATTN_HEADERS steel/defines.h steel/utils.h steel/gemm/gemm.h steel/gemm/mma.h steel/gemm/loader.h steel/gemm/transforms.h steel/utils/type_traits.h steel/utils/integral_constant.h steel/attn/attn.h steel/attn/loader.h steel/attn/mma.h steel/attn/params.h steel/attn/transforms.h steel/attn/kernels/steel_attention.h) set(STEEL_NAX_HEADERS steel/defines.h steel/utils.h steel/gemm/params.h steel/gemm/transforms.h steel/gemm/nax.h steel/gemm/gemm_nax.h steel/utils/type_traits.h steel/utils/integral_constant.h steel/gemm/kernels/steel_gemm_fused_nax.h steel/gemm/kernels/steel_gemm_gather_nax.h steel/gemm/kernels/steel_gemm_splitk_nax.h) set(STEEL_NAX_ATTN_HEADERS steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h steel/utils/integral_constant.h steel/attn/params.h steel/attn/kernels/steel_attention_nax.h) if(NOT MLX_METAL_JIT) build_kernel(arange arange.h) build_kernel(binary binary.h binary_ops.h) build_kernel(binary_two binary_two.h) build_kernel(copy copy.h) build_kernel(fft fft.h fft/radix.h fft/readwrite.h) build_kernel( reduce atomic.h reduction/ops.h reduction/reduce_init.h reduction/reduce_all.h reduction/reduce_col.h reduction/reduce_row.h) build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS}) build_kernel(fp_quantized fp4.h fp8.h fp_quantized.h quantized_utils.h ${STEEL_HEADERS}) build_kernel(scan scan.h) 
build_kernel(softmax softmax.h) build_kernel(logsumexp logsumexp.h) build_kernel(sort sort.h) build_kernel(ternary ternary.h ternary_ops.h) build_kernel(unary unary.h unary_ops.h) build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS}) build_kernel(steel/conv/kernels/steel_conv_3d ${STEEL_HEADERS}) build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS}) build_kernel(gemv_masked steel/utils.h) build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS}) if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL 26.2)) build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk_nax ${STEEL_NAX_HEADERS}) build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS}) build_kernel(fp_quantized_nax fp4.h fp8.h fp_quantized_nax.h ${STEEL_NAX_HEADERS}) build_kernel(steel/attn/kernels/steel_attention_nax ${STEEL_NAX_ATTN_HEADERS}) else() target_compile_definitions(mlx PRIVATE MLX_METAL_NO_NAX) endif() endif() add_custom_command( OUTPUT ${MLX_METAL_PATH}/mlx.metallib COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib DEPENDS ${KERNEL_AIR} COMMENT "Building mlx.metallib" VERBATIM) add_custom_target(mlx-metallib DEPENDS ${MLX_METAL_PATH}/mlx.metallib) add_dependencies(mlx mlx-metallib) # Install metallib include(GNUInstallDirs) install( FILES ${MLX_METAL_PATH}/mlx.metallib DESTINATION ${CMAKE_INSTALL_LIBDIR} COMPONENT metallib) ================================================ FILE: mlx/backend/metal/kernels/arange.h 
================================================
// Copyright © 2023-2024 Apple Inc.

// Fills one element per thread: out[index] = start + index * step, where
// index is the thread's position in the grid.
// NOTE(review): for half/bfloat16 instantiations `index * step` is presumably
// evaluated in T, so very large indices lose precision — confirm callers are
// fine with that.
template [[kernel]] void arange(
    constant const T& start,
    constant const T& step,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  out[index] = start + index * step;
}

================================================
FILE: mlx/backend/metal/kernels/arange.metal
================================================
// Copyright © 2023-2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/arange.h"

// Instantiates arange for `type` under the kernel name "arange" #tname
// (e.g. "arangefloat32") via the instantiate_kernel macro.
#define instantiate_arange(tname, type) \
  instantiate_kernel("arange" #tname, arange, type)

instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint64, uint64_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int16, int16_t)
instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t)
// clang-format on

================================================
FILE: mlx/backend/metal/kernels/arg_reduce.metal
================================================
// Copyright © 2023 Apple Inc.
#include
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

// A candidate result: a position along the reduction axis and its value.
template struct IndexValPair {
  uint32_t index;
  U val;
};

// Arg-min policy. `init` (Limits::max) is the starting value; ties are
// resolved towards the smaller index.
template struct ArgMin {
  static constexpr constant U init = Limits::max;

  // Merges two partial results: smaller value wins, smaller index on ties.
  IndexValPair reduce(IndexValPair best, IndexValPair current) {
    if (best.val > current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  // Folds N values at positions offset, offset + 1, ... into `best`. The
  // strict `<` keeps the earlier position when values are equal.
  template
  IndexValPair reduce_many(IndexValPair best, thread U* vals, uint32_t offset) {
    for (int i = 0; i < N; i++) {
      if (vals[i] < best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

// Arg-max policy. `init` (Limits::min) is the starting value; ties are
// resolved towards the smaller index.
template struct ArgMax {
  static constexpr constant U init = Limits::min;

  // Merges two partial results: larger value wins, smaller index on ties.
  IndexValPair reduce(IndexValPair best, IndexValPair current) {
    if (best.val < current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  // Folds N values at positions offset, offset + 1, ... into `best`. The
  // strict `>` keeps the earlier position when values are equal.
  template
  IndexValPair reduce_many(IndexValPair best, thread U* vals, uint32_t offset) {
    for (int i = 0; i < N; i++) {
      if (vals[i] > best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

// Shuffles both fields of an IndexValPair down by `delta` SIMD lanes.
template IndexValPair simd_shuffle_down(IndexValPair data, uint16_t delta) {
  return IndexValPair{
      simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
}

template
[[kernel]] void arg_reduce_general(
    const device T* in [[buffer(0)]],
    device uint32_t* out [[buffer(1)]],
    const constant int* shape [[buffer(2)]],
    const constant int64_t* in_strides [[buffer(3)]],
    const constant int64_t* out_strides [[buffer(4)]],
    const constant size_t& ndim [[buffer(5)]],
    const constant int64_t& axis_stride [[buffer(6)]],
    const constant size_t& axis_size [[buffer(7)]],
    uint3 gid [[thread_position_in_grid]],
    uint3 gsize [[threads_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Shapes and strides *do not* contain the reduction axis. The reduction size
  // and stride are provided in axis_size and axis_stride.
  //
  // Note: in shape == out shape with this convention.
  //
  // The sketch of the kernel is as follows.
  //    1. Launch prod(shape) * thread_group_size threads.
  //    2. Loop ceildiv(axis_size, N_READS * lsize.x) times
  //    3. Read N_READS input values
  //    4. Reduce among them and go to 3
  //    5. Reduce in each simd_group
  //    6. Write in the threadgroup memory
  //    7. Reduce them across thread group
  //    8. Write the output without need for atomic
  Op op;

  // Compute the input/output index. There is one beginning and one output for
  // the whole threadgroup.
  int64_t row_idx = gid.y + static_cast(gsize.y) * gid.z;
  auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim);
  auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim);

  IndexValPair best{0, Op::init};

  // One slot per simdgroup, indexed by simd_group_id, so a threadgroup may
  // contain at most 32 simdgroups.
  threadgroup IndexValPair local_data[32];

  // Loop over the reduction axis in lsize*N_READS buckets
  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
    // Read the current value
    uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS;
    uint32_t offset = current_index;
    const device T* current_in = in + in_idx + current_index * axis_stride;
    T vals[N_READS];
    // Positions past axis_size are padded with Op::init; the strict
    // comparisons in reduce_many never select the padding.
    for (int i = 0; i < N_READS; i++) {
      vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
      current_index++;
      current_in += axis_stride;
    }
    best = op.template reduce_many(best, vals, offset);
  }
  // At this point we have reduced the axis into thread group best values so we
  // need to reduce across the thread group.

  // First per simd reduction.
  for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
    IndexValPair neighbor = simd_shuffle_down(best, offset);
    best = op.reduce(best, neighbor);
  }

  // Write to the threadgroup memory
  if (simd_lane_id == 0) {
    local_data[simd_group_id] = best;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Only the first simdgroup performs the final reduction.
  if (simd_group_id != 0) {
    return;
  }

  // Read the appropriate value from local data and perform one simd reduction.
  // Lanes at or past simd_groups keep their first-stage partials, which only
  // cover simdgroup 0's data (already in local_data[0]), so re-reducing them
  // does not change the result.
  uint simd_groups = ceildiv(lsize.x, simd_size);
  if (simd_lane_id < simd_groups) {
    best = local_data[simd_lane_id];
  }
  for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
    IndexValPair neighbor = simd_shuffle_down(best, offset);
    best = op.reduce(best, neighbor);
  }

  // Finally write the output
  if (lid.x == 0) {
    out[out_idx] = best.index;
  }
}

// clang-format off
// Instantiates "argmin_" #name and "argmax_" #name for element type itype.
#define instantiate_arg_reduce(name, itype)                 \
  instantiate_kernel(                                       \
      "argmin_" #name, arg_reduce_general, itype, ArgMin)   \
  instantiate_kernel(                                       \
      "argmax_" #name, arg_reduce_general, itype, ArgMax)

instantiate_arg_reduce(bool_, bool)
instantiate_arg_reduce(uint8, uint8_t)
instantiate_arg_reduce(uint16, uint16_t)
instantiate_arg_reduce(uint32, uint32_t)
instantiate_arg_reduce(uint64, uint64_t)
instantiate_arg_reduce(int8, int8_t)
instantiate_arg_reduce(int16, int16_t)
instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t)
// clang-format on

================================================
FILE: mlx/backend/metal/kernels/atomic.h
================================================
// Copyright © 2023 Apple Inc.
#pragma once #include #include using namespace metal; /////////////////////////////////////////////////////////////////////////////// // Atomic utils /////////////////////////////////////////////////////////////////////////////// #pragma METAL internals : enable template constexpr constant bool is_metal_atomic = _disjunction< is_same, is_same, is_same, is_same>::value; #pragma METAL internals : disable template struct mlx_atomic { atomic val; }; template struct mlx_atomic>> { atomic val; }; /////////////////////////////////////////////////////////////////////////////// // Native metal atomics /////////////////////////////////////////////////////////////////////////////// template , bool> = true> METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { return atomic_load_explicit(&(object[offset].val), memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_and_explicit( device mlx_atomic* object, T val, size_t offset) { atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_or_explicit( device mlx_atomic* object, T val, size_t offset) { atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_min_explicit( device mlx_atomic* object, T val, size_t offset) { atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_max_explicit( device mlx_atomic* object, T val, size_t offset) { atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_add_explicit( device mlx_atomic* object, T val, size_t offset) { 
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_mul_explicit( device mlx_atomic* object, T val, size_t offset) { T expected = mlx_atomic_load_explicit(object, offset); while (!mlx_atomic_compare_exchange_weak_explicit( object, &expected, val * expected, offset)) { } } template , bool> = true> METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( device mlx_atomic* object, thread T* expected, T val, size_t offset) { return atomic_compare_exchange_weak_explicit( &(object[offset].val), expected, val, memory_order_relaxed, memory_order_relaxed); } // Specialization for float since it does not atomic_fetch_min_explicit template <> METAL_FUNC void mlx_atomic_fetch_min_explicit( device mlx_atomic* object, float val, size_t offset) { float expected = mlx_atomic_load_explicit(object, offset); while (val < expected) { if (mlx_atomic_compare_exchange_weak_explicit( object, &expected, val, offset)) { return; } } } // Specialization for float since it does not atomic_fetch_max_explicit template <> METAL_FUNC void mlx_atomic_fetch_max_explicit( device mlx_atomic* object, float val, size_t offset) { float expected = mlx_atomic_load_explicit(object, offset); while (val > expected) { if (mlx_atomic_compare_exchange_weak_explicit( object, &expected, val, offset)) { return; } } } /////////////////////////////////////////////////////////////////////////////// // Custom atomics /////////////////////////////////////////////////////////////////////////////// namespace { template constexpr constant uint packing_size = sizeof(uint) / sizeof(T); template union uint_or_packed { T val[packing_size]; uint bits; }; template struct mlx_atomic_update_helper { uint operator()(uint_or_packed init, T update, size_t elem_offset) { Op op; init.val[elem_offset] = op(update, init.val[elem_offset]); return init.bits; } }; template METAL_FUNC void mlx_atomic_update_and_store( device mlx_atomic* object, T update, 
size_t offset) { size_t pack_offset = offset / packing_size; size_t elem_offset = offset % packing_size; mlx_atomic_update_helper helper; uint_or_packed expected; expected.bits = atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); while (Op::condition(update, expected.val[elem_offset]) && !mlx_atomic_compare_exchange_weak_explicit( object, &(expected.bits), helper(expected, update, elem_offset), pack_offset)) { } } template struct __None { static bool condition(T a, T b) { #pragma unused(a) #pragma unused(b) return true; } T operator()(T a, T b) { #pragma unused(b) return a; } }; template struct __Add { static bool condition(T a, T b) { #pragma unused(a) #pragma unused(b) return true; } T operator()(T a, T b) { return a + b; } }; template struct __Mul { static bool condition(T a, T b) { #pragma unused(a) return b != 0; } T operator()(T a, T b) { return a * b; } }; template struct __Max { static bool condition(T a, T b) { return a > b; } T operator()(T a, T b) { return max(a, b); } }; template struct __Min { static bool condition(T a, T b) { return a < b; } T operator()(T a, T b) { return min(a, b); } }; } // namespace template , bool> = true> METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { size_t pack_offset = offset / sizeof(T); size_t elem_offset = offset % sizeof(T); uint_or_packed packed_val; packed_val.bits = atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); return packed_val.val[elem_offset]; } template , bool> = true> METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_and_explicit( device mlx_atomic* object, T val, size_t offset) { size_t pack_offset = offset / packing_size; size_t elem_offset = offset % packing_size; uint_or_packed identity; identity.bits = __UINT32_MAX__; identity.val[elem_offset] = val; atomic_fetch_and_explicit( 
&(object[pack_offset].val), identity.bits, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_or_explicit( device mlx_atomic* object, T val, size_t offset) { size_t pack_offset = offset / packing_size; size_t elem_offset = offset % packing_size; uint_or_packed identity; identity.bits = 0; identity.val[elem_offset] = val; atomic_fetch_or_explicit( &(object[pack_offset].val), identity.bits, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_min_explicit( device mlx_atomic* object, T val, size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_max_explicit( device mlx_atomic* object, T val, size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_add_explicit( device mlx_atomic* object, T val, size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_mul_explicit( device mlx_atomic* object, T val, size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( device mlx_atomic* object, thread uint* expected, uint val, size_t offset) { return atomic_compare_exchange_weak_explicit( &(object[offset].val), expected, val, memory_order_relaxed, memory_order_relaxed); } ================================================ FILE: mlx/backend/metal/kernels/bf16.h ================================================ // Copyright © 2023 Apple Inc. 
#pragma once

#include

using namespace metal;

typedef bfloat bfloat16_t;

// Reinterprets the 16 bits of a bfloat16 value as an unsigned integer.
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
  return as_type(x);
}

// Reinterprets 16 raw bits as a bfloat16 value (inverse of the above).
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
  return as_type(x);
}

================================================
FILE: mlx/backend/metal/kernels/bf16_math.h
================================================
// Copyright © 2023 Apple Inc.

#pragma once

///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////

/*

Following the Metal Shading Language Specification (Metal 3.1)

"bfloat is an extended floating point type that only allows implicit conversion
 to a type of greater floating point rank. While bfloat can be implicitly
 converted to float, it cannot be implicitly converted to half, and neither
 float nor half can be implicitly converted to bfloat."

Further, as far as I can tell, the stdlib math/simd functions are not defined
for bfloat and calling with an argument of type bfloat will result in that
argument getting implicitly converted to float which then returns an output
that is (likely) a float which cannot be implicitly converted into a bfloat

This leads to situations where
bfloat a = 5.0bf;
bfloat b = metal::abs(a); // this will throw an error since abs returns float
bfloat c = static_cast(metal::abs(a)); // this is fine

For the moment, I will be adding overloaded instantiations of the math
functions to accordingly automatically handle the casting

*/

// Defines `otype f(itype ...)` overloads of the Metal math functions. Most
// of them cast the arguments to ctype, call the matching __metal_* builtin
// (passing the `mfast` math-mode argument where the builtin takes one) and
// cast the result back to otype. fdim is computed with select() instead.
// Instantiated below with itype = otype = bfloat16_t and ctype = float.
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
                                                                 \
  METAL_FUNC otype abs(itype x) { \
    return static_cast(__metal_fabs(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype acos(itype x) { \
    return static_cast(__metal_acos(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype acosh(itype x) { \
    return static_cast(__metal_acosh(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype asin(itype x) { \
    return static_cast(__metal_asin(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype asinh(itype x) { \
    return static_cast(__metal_asinh(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype atan(itype y_over_x) { \
    return static_cast( \
        __metal_atan(static_cast(y_over_x), mfast)); \
  } \
  METAL_FUNC otype atan2(itype y, itype x) { \
    return static_cast( \
        __metal_atan2(static_cast(y), static_cast(x), mfast)); \
  } \
  METAL_FUNC otype atanh(itype x) { \
    return static_cast(__metal_atanh(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype ceil(itype x) { \
    return static_cast(__metal_ceil(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype cos(itype x) { \
    return static_cast(__metal_cos(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype cosh(itype x) { \
    return static_cast(__metal_cosh(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype cospi(itype x) { \
    return static_cast(__metal_cospi(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype divide(itype x, itype y) { \
    return static_cast( \
        __metal_divide(static_cast(x), static_cast(y), mfast)); \
  } \
  METAL_FUNC otype exp(itype x) { \
    return static_cast(__metal_exp(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype exp10(itype x) { \
    return static_cast(__metal_exp10(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype exp2(itype x) { \
    return static_cast(__metal_exp2(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype fabs(itype x) { \
    return static_cast(__metal_fabs(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype fdim(itype x, itype y) { \
    ctype t = static_cast(x - y); \
    return static_cast(select(t, ctype(0), t < ctype(0) || x == y)); \
  } \
  METAL_FUNC otype floor(itype x) { \
    return static_cast(__metal_floor(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype fma(itype x, itype y, itype z) { \
    return static_cast(__metal_fma( \
        static_cast(x), static_cast(y), static_cast(z))); \
  } \
  METAL_FUNC otype fmax(itype x, itype y) { \
    return static_cast( \
        __metal_fmax(static_cast(x), static_cast(y), mfast)); \
  } \
  METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
    return static_cast(__metal_fmax3( \
        static_cast(x), \
        static_cast(y), \
        static_cast(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
    return static_cast(__metal_fmedian3( \
        static_cast(x), \
        static_cast(y), \
        static_cast(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmin(itype x, itype y) { \
    return static_cast( \
        __metal_fmin(static_cast(x), static_cast(y), mfast)); \
  } \
  METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
    return static_cast(__metal_fmin3( \
        static_cast(x), \
        static_cast(y), \
        static_cast(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmod(itype x, itype y) { \
    return static_cast( \
        __metal_fmod(static_cast(x), static_cast(y), mfast)); \
  } \
  METAL_FUNC otype fract(itype x) { \
    return static_cast(__metal_fract(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype frexp(itype x, thread int& exp) { \
    return static_cast(__metal_frexp(static_cast(x), &exp)); \
  } \
  METAL_FUNC otype ldexp(itype x, int k) { \
    return static_cast(__metal_ldexp(static_cast(x), k, mfast)); \
  } \
  METAL_FUNC otype log(itype x) { \
    return static_cast(__metal_log(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype log10(itype x) { \
    return static_cast(__metal_log10(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype log2(itype x) { \
    return static_cast(__metal_log2(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype max(itype x, itype y) { \
    return static_cast( \
        __metal_fmax(static_cast(x), static_cast(y), mfast)); \
  } \
  METAL_FUNC otype max3(itype x, itype y, itype z) { \
    return static_cast(__metal_fmax3( \
        static_cast(x), \
        static_cast(y), \
        static_cast(z), \
        mfast)); \
  } \
  METAL_FUNC otype median3(itype x, itype y, itype z) { \
    return static_cast(__metal_fmedian3( \
        static_cast(x), \
        static_cast(y), \
        static_cast(z), \
        mfast)); \
  } \
  METAL_FUNC otype min(itype x, itype y) { \
    return static_cast( \
        __metal_fmin(static_cast(x), static_cast(y), mfast)); \
  } \
  METAL_FUNC otype min3(itype x, itype y, itype z) { \
    return static_cast(__metal_fmin3( \
        static_cast(x), \
        static_cast(y), \
        static_cast(z), \
        mfast)); \
  } \
  METAL_FUNC otype nextafter(itype x, itype y) { \
    return static_cast( \
        __metal_nextafter(static_cast(x), static_cast(y))); \
  } \
  METAL_FUNC otype pow(itype x, itype y) { \
    return static_cast( \
        __metal_pow(static_cast(x), static_cast(y), mfast)); \
  } \
  METAL_FUNC otype powr(itype x, itype y) { \
    return static_cast( \
        __metal_powr(static_cast(x), static_cast(y), mfast)); \
  } \
  METAL_FUNC otype rint(itype x) { \
    return static_cast(__metal_rint(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype round(itype x) { \
    return static_cast(__metal_round(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype rsqrt(itype x) { \
    return static_cast(__metal_rsqrt(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype sin(itype x) { \
    return static_cast(__metal_sin(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype sinh(itype x) { \
    return static_cast(__metal_sinh(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype sinpi(itype x) { \
    return static_cast(__metal_sinpi(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype sqrt(itype x) { \
    return static_cast(__metal_sqrt(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype tan(itype x) { \
    return static_cast(__metal_tan(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype tanh(itype x) { \
    return static_cast(__metal_tanh(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype tanpi(itype x) { \
    return static_cast(__metal_tanpi(static_cast(x), mfast)); \
  } \
  METAL_FUNC otype trunc(itype x) { \
    return static_cast(__metal_trunc(static_cast(x), mfast)); \
  }

// bfloat16 overloads in metal::, metal::fast:: and metal::precise::, each
// passing the corresponding math-mode argument to the builtins.
namespace metal {

instantiate_metal_math_funcs(
    bfloat16_t,
    bfloat16_t,
    float,
    __METAL_MAYBE_FAST_MATH__);

namespace fast {

instantiate_metal_math_funcs(
    bfloat16_t,
    bfloat16_t,
    float,
    __METAL_FAST_MATH__);

} // namespace fast

namespace precise {

instantiate_metal_math_funcs(
    bfloat16_t,
    bfloat16_t,
    float,
    __METAL_PRECISE_MATH__);

} // namespace precise

} // namespace metal

///////////////////////////////////////////////////////////////////////////////
// Metal simd for bfloat16
/////////////////////////////////////////////////////////////////////////////// #define instantiate_metal_simd_comm_funcs( \ itype, otype, ctype, itype_to_ctype, ctype_to_otype) \ \ METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \ return ctype_to_otype( \ __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \ } \ \ METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \ return ctype_to_otype( \ __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \ } \ \ METAL_FUNC otype simd_shuffle_and_fill_down( \ itype data, itype filling_data, ushort delta, ushort modulo) { \ return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ } \ \ METAL_FUNC otype simd_shuffle_and_fill_down( \ itype data, itype filling_data, ushort delta) { \ return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ itype_to_ctype(data), \ itype_to_ctype(filling_data), \ delta, \ __metal_get_simdgroup_size(ushort()))); \ } \ \ METAL_FUNC otype simd_shuffle_and_fill_up( \ itype data, itype filling_data, ushort delta, ushort modulo) { \ return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ } \ \ METAL_FUNC otype simd_shuffle_and_fill_up( \ itype data, itype filling_data, ushort delta) { \ return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ itype_to_ctype(data), \ itype_to_ctype(filling_data), \ delta, \ __metal_get_simdgroup_size(ushort()))); \ } \ \ METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \ return ctype_to_otype( \ __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \ } \ \ METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \ return ctype_to_otype( \ __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \ } \ \ METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \ return ctype_to_otype( \ 
__metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \ } \ \ METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \ return ctype_to_otype( \ __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \ } \ \ METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \ return ctype_to_otype( \ __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \ } #define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \ \ METAL_FUNC otype simd_max(itype data) { \ return static_cast(__metal_simd_max(static_cast(data))); \ } \ \ METAL_FUNC otype simd_min(itype data) { \ return static_cast(__metal_simd_min(static_cast(data))); \ } \ \ METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \ return static_cast( \ __metal_simd_prefix_exclusive_product(static_cast(data))); \ } \ \ METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \ return static_cast( \ __metal_simd_prefix_exclusive_sum(static_cast(data))); \ } \ \ METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \ return static_cast( \ __metal_simd_prefix_inclusive_product(static_cast(data))); \ } \ \ METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \ return static_cast( \ __metal_simd_prefix_inclusive_sum(static_cast(data))); \ } \ \ METAL_FUNC otype simd_product(itype data) { \ return static_cast(__metal_simd_product(static_cast(data))); \ } \ \ METAL_FUNC otype simd_sum(itype data) { \ return static_cast(__metal_simd_sum(static_cast(data))); \ } \ \ METAL_FUNC otype simd_xor(itype data) { \ return static_cast(__metal_simd_xor(static_cast(data))); \ } namespace metal { instantiate_metal_simd_comm_funcs( bfloat16_t, bfloat16_t, uint16_t, bfloat16_to_uint16, uint16_to_bfloat16); instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float); } // namespace metal ================================================ FILE: mlx/backend/metal/kernels/binary.h ================================================ // Copyright © 2024 Apple Inc. 
// Scalar-scalar: both inputs hold a single value; every output element is
// Op()(a[0], b[0]).
template [[kernel]] void binary_ss(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[0], b[0]);
}

// Scalar-vector: `a` is a single value, `b` is read contiguously. Each thread
// handles N consecutive elements starting at index * N.
template ::n> [[kernel]] void binary_sv(
    device const T* a,
    device const T* b,
    device U* c,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  index *= N;
  if (N > 1 && index + N > size) {
    // This thread's chunk crosses the end of the array: stop at `size`.
    for (int i = 0; index + i < size; ++i) {
      c[index + i] = Op()(a[0], b[index + i]);
    }
  } else {
    for (int i = 0; i < N; ++i) {
      c[index + i] = Op()(a[0], b[index + i]);
    }
  }
}

// Vector-scalar: `a` is read contiguously, `b` is a single value. Same chunking
// and tail handling as binary_sv.
template ::n> [[kernel]] void binary_vs(
    device const T* a,
    device const T* b,
    device U* c,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  index *= N;
  if (N > 1 && index + N > size) {
    for (int i = 0; index + i < size; ++i) {
      c[index + i] = Op()(a[index + i], b[0]);
    }
  } else {
    for (int i = 0; i < N; ++i) {
      c[index + i] = Op()(a[index + i], b[0]);
    }
  }
}

// Vector-vector: both inputs are read contiguously at the same index. Same
// chunking and tail handling as binary_sv.
template ::n> [[kernel]] void binary_vv(
    device const T* a,
    device const T* b,
    device U* c,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  index *= N;
  if (N > 1 && index + N > size) {
    for (int i = 0; index + i < size; ++i) {
      c[index + i] = Op()(a[index + i], b[index + i]);
    }
  } else {
    for (int i = 0; i < N; ++i) {
      c[index + i] = Op()(a[index + i], b[index + i]);
    }
  }
}

// 2-D-grid variant of binary_sv. The linear offset
// N * (index.x + grid_dim.x * index.y) and the size are int64_t, presumably so
// arrays too large for 32-bit indexing are handled.
template ::n> [[kernel]] void binary_sv2(
    device const T* a,
    device const T* b,
    device U* c,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    for (int i = 0; offset + i < size; ++i) {
      c[offset + i] = Op()(a[0], b[offset + i]);
    }
  } else {
    for (int i = 0; i < N; ++i) {
      c[offset + i] = Op()(a[0], b[offset + i]);
    }
  }
}

// 2-D-grid variant of binary_vs (64-bit offsets, as in binary_sv2).
template ::n> [[kernel]] void binary_vs2(
    device const T* a,
    device const T* b,
    device U* c,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    for (int i = 0; offset + i < size; ++i) {
      c[offset + i] = Op()(a[offset + i], b[0]);
    }
  } else {
    for (int i = 0; i < N; ++i) {
      c[offset + i] = Op()(a[offset + i], b[0]);
    }
  }
}

// 2-D-grid variant of binary_vv (64-bit offsets, as in binary_sv2).
template ::n> [[kernel]] void binary_vv2(
    device const T* a,
    device const T* b,
    device U* c,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    for (int i = 0; offset + i < size; ++i) {
      c[offset + i] = Op()(a[offset + i], b[offset + i]);
    }
  } else {
    for (int i = 0; i < N; ++i) {
      c[offset + i] = Op()(a[offset + i], b[offset + i]);
    }
  }
}

// General strided kernel, 1-D: each input is located through its own stride
// via elem_to_loc_1; the output is written contiguously at `index`.
template [[kernel]] void binary_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int64_t& a_stride,
    constant const int64_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_stride);
  auto b_idx = elem_to_loc_1(index, b_stride);
  c[index] = Op()(a[a_idx], b[b_idx]);
}

// General strided kernel, 2-D: inputs located via elem_to_loc_2; the output
// is written contiguously at index.x + grid_dim.x * index.y.
template [[kernel]] void binary_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int64_t a_strides[2],
    constant const int64_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
  c[out_idx] = Op()(a[a_idx], b[b_idx]);
}

// General strided kernel, 3-D: inputs located via elem_to_loc_3; the output
// is written contiguously in row-major order of the launch grid.
template [[kernel]] void binary_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int64_t a_strides[3],
    constant const int64_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
  c[out_idx] = Op()(a[a_idx], b[b_idx]);
}

// General N-d kernel. elem_to_loc_2_nd gives the starting offsets of both
// inputs for element (N * index.x, index.y, index.z); the thread then walks up
// to N elements along the last axis, advancing each input by its last-axis
// stride and writing the output contiguously.
template <
    typename T,
    typename U,
    typename Op,
    int N = 1,
    typename IdxT = int64_t>
[[kernel]] void binary_g(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(
      {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
  auto xshape = shape[ndim - 1];
  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
  IdxT a_xstride = a_strides[ndim - 1];
  IdxT b_xstride = b_strides[ndim - 1];
  // Stop early when the last axis length is not a multiple of N.
  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
    c[out_idx++] = Op()(a[idx.x], b[idx.y]);
    idx.x += a_xstride;
    idx.y += b_xstride;
  }
}

================================================
FILE: mlx/backend/metal/kernels/binary.metal
================================================
// Copyright © 2024 Apple Inc.

#include
#include

// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"

// Contiguous kernels with the default elements-per-thread N ("n" suffix).
// NOTE: the last line of this macro ends in a backslash, so the blank line
// after it is spliced into the definition; keep that line blank.
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
  instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
  instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
  instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \

// Base kernel set for one (op, type) pair: ss; sv/vs/vv with N = 1; the 2-D
// grid sv2/vs2/vv2; general gn2 (N = 2, int indexing) and gn4large (N = 4,
// default index type); and the fixed-rank g1/g2/g3 kernels with int indexing
// and with the default index type ("large").
#define instantiate_binary_base(op, tname, itype, otype) \
  instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
  instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
  instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
  instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
  instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
  instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
  instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
  instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
  instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
  instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
  instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
  instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)

// Base set plus the work-per-thread variants.
#define instantiate_binary_all(op, tname, itype, otype) \
  instantiate_binary_base(op, tname, itype, otype) \
  instantiate_binary_work_per_thread(op, tname, itype, otype)

// Integer types; 64-bit types get only the base set (no svn/vsn/vvn).
#define instantiate_binary_integer(op) \
  instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
  instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
  instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
  instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
  instantiate_binary_all(op, int8, int8_t, int8_t) \
  instantiate_binary_all(op, int16, int16_t, int16_t) \
  instantiate_binary_all(op, int32, int32_t, int32_t) \
  instantiate_binary_base(op, int64, int64_t, int64_t)

#define instantiate_binary_float(op) \
  instantiate_binary_all(op, float16, half, half) \
  instantiate_binary_all(op, float32, float, float) \
  instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)

// Every supported type with output type == input type (complex64: base only).
#define instantiate_binary_types(op) \
  instantiate_binary_all(op, bool_, bool, bool) \
  instantiate_binary_integer(op) \
  instantiate_binary_base(op, complex64, complex64_t, complex64_t)\
  instantiate_binary_float(op)

// Every supported input type with a bool output (comparison ops).
#define instantiate_binary_types_bool(op) \
  instantiate_binary_all(op, bool_, bool, bool) \
  instantiate_binary_all(op, uint8, uint8_t, bool) \
  instantiate_binary_all(op, uint16, uint16_t, bool) \
  instantiate_binary_all(op, uint32, uint32_t, bool) \
  instantiate_binary_base(op, uint64, uint64_t, bool) \
  instantiate_binary_all(op, int8, int8_t, bool) \
  instantiate_binary_all(op, int16, int16_t, bool) \
  instantiate_binary_all(op, int32, int32_t, bool) \
  instantiate_binary_base(op, int64, int64_t, bool) \
  instantiate_binary_all(op, float16, half, bool) \
  instantiate_binary_all(op, float32, float, bool) \
  instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
  instantiate_binary_base(op, complex64, complex64_t, bool)

instantiate_binary_types(Add)
instantiate_binary_types(Divide)
instantiate_binary_types_bool(Equal)
instantiate_binary_types_bool(Greater)
instantiate_binary_types_bool(GreaterEqual)
instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp)
instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply)
instantiate_binary_types(Subtract)
instantiate_binary_types(Power)
instantiate_binary_types(Remainder)
instantiate_binary_float(ArcTan2)

// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(NaNEqual, float16, half, bool)
instantiate_binary_all(NaNEqual, float32, float, bool)
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)

instantiate_binary_all(LogicalOr, bool_, bool, bool)
instantiate_binary_all(LogicalAnd, bool_, bool, bool)

// Bitwise ops only need integer types and bool (except for l/r shift)
instantiate_binary_integer(BitwiseAnd)
instantiate_binary_all(BitwiseAnd, bool_, bool, bool)
instantiate_binary_integer(BitwiseOr)
instantiate_binary_all(BitwiseOr, bool_, bool, bool)
instantiate_binary_integer(BitwiseXor)
instantiate_binary_all(BitwiseXor, bool_, bool, bool)
instantiate_binary_integer(LeftShift)
instantiate_binary_integer(RightShift) // clang-format on

================================================
FILE: mlx/backend/metal/kernels/binary_ops.h
================================================
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include
#include

constant mlx::os_log logger("mlx", "binary_ops");

// Element-wise binary functors. Each exposes operator()(x, y); comparison
// functors return bool.

struct Add {
  template T operator()(T x, T y) {
    return x + y;
  }
};

// Quotient half of DivMod. The floating-point specializations round toward
// negative infinity, so that, together with Remainder (whose nonzero result
// takes the sign of the divisor), x ~= q * y + r. Example: -7 / 2 gives
// q = -4, r = 1. Rounding with trunc would give q = -3, which does not
// pair with r = 1.
// NOTE(review): integer types still use `x / y`, which truncates toward zero.
// For signed operands of opposite sign that quotient does not pair with
// Remainder's result; confirm whether integer DivMod should also floor.
struct FloorDivide {
  template T operator()(T x, T y) {
    return x / y;
  }
  template <> float operator()(float x, float y) {
    return floor(x / y);
  }
  template <> half operator()(half x, half y) {
    return floor(x / y);
  }
  template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
    return floor(x / y);
  }
};

struct Divide {
  template T operator()(T x, T y) {
    return x / y;
  }
};

// Remainder whose nonzero result takes the sign of the divisor: the `%` /
// fmod result is shifted by y when its sign differs from y's. The first
// overload (non-signed types, per its enable_if) uses plain `%`.
struct Remainder {
  template metal::enable_if_t & !metal::is_signed_v, T> operator()(T x, T y) {
    return x % y;
  }
  template metal::enable_if_t & metal::is_signed_v, T> operator()(T x, T y) {
    auto r = x % y;
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template metal::enable_if_t, T> operator()(T x, T y) {
    T r = fmod(x, y);
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  // Delegates to complex64_t's operator%.
  template <> complex64_t operator()(complex64_t x, complex64_t y) {
    return x % y;
  }
};

struct Equal {
  template bool operator()(T x, T y) {
    return x == y;
  }
};

// Equality that treats NaN as equal to NaN. For complex values each component
// must match, with NaN matching NaN in that component.
struct NaNEqual {
  template bool operator()(T x, T y) {
    return x == y || (metal::isnan(x) && metal::isnan(y));
  }
  template <> bool operator()(complex64_t x, complex64_t y) {
    return x == y ||
        (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
         metal::isnan(y.imag)) ||
        (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
        (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
  }
};

struct Greater {
  template bool operator()(T x, T y) {
    return x > y;
  }
};

struct GreaterEqual {
  template bool operator()(T x, T y) {
    return x >= y;
  }
};

struct Less {
  template bool operator()(T x, T y) {
    return x < y;
  }
};

struct LessEqual {
  template bool operator()(T x, T y) {
    return x <= y;
  }
};

// log(exp(x) + exp(y)) computed as max + log1p(exp(min - max)) so exp never
// sees a large positive argument. NaN in either input gives NaN; if min is
// -inf or max is +inf the result is max.
struct LogAddExp {
  template T operator()(T x, T y) {
    if (metal::isnan(x) || metal::isnan(y)) {
      return metal::numeric_limits::quiet_NaN();
    }
    constexpr T inf = metal::numeric_limits::infinity();
    T maxval = metal::max(x, y);
    T minval = metal::min(x, y);
    return (minval == -inf || maxval == inf)
        ? maxval
        : (maxval + log1p(metal::exp(minval - maxval)));
  };

  // Complex version: orders the operands with complex64_t's lexicographic
  // comparison and adds log1p(exp(minval - maxval)), with the complex
  // exponential built from its real and imaginary parts.
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
        metal::isnan(y.imag)) {
      // NOTE(review): a scalar NaN converted to complex64_t presumably yields
      // imag == 0 rather than NaN; confirm this is intended.
      return metal::numeric_limits::quiet_NaN();
    }
    constexpr float inf = metal::numeric_limits::infinity();
    complex64_t maxval = x > y ? x : y;
    complex64_t minval = x < y ? x : y;
    if (minval.real == -inf || maxval.real == inf)
      return maxval;
    float m = metal::exp(minval.real - maxval.real);
    complex64_t dexp{
        m * metal::cos(minval.imag - maxval.imag),
        m * metal::sin(minval.imag - maxval.imag),
    };
    return maxval + log1p(dexp);
  }
};

// Maximum; the floating-point and complex overloads propagate NaN (x is
// returned if it is NaN; otherwise `x > y ? x : y` yields y when y is NaN).
struct Maximum {
  template metal::enable_if_t, T> operator()(T x, T y) {
    return metal::max(x, y);
  }

  template metal::enable_if_t, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x > y ? x : y;
  }

  template <> complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x > y ? x : y;
  }
};

// Minimum; NaN handling mirrors Maximum.
struct Minimum {
  template metal::enable_if_t, T> operator()(T x, T y) {
    return metal::min(x, y);
  }

  template metal::enable_if_t, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x < y ? x : y;
  }

  template <> complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x < y ? x : y;
  }
};

struct Multiply {
  template T operator()(T x, T y) {
    return x * y;
  }
};

struct NotEqual {
  template bool operator()(T x, T y) {
    return x != y;
  }
  template <> bool operator()(complex64_t x, complex64_t y) {
    return x.real != y.real || x.imag != y.imag;
  }
};

struct Power {
  template metal::enable_if_t, T> operator()(T base, T exp) {
    return metal::pow(base, exp);
  }

  // Integer power by repeated squaring.
  template metal::enable_if_t, T> operator()(T base, T exp) {
    T res = 1;
    // Undefined to raise integer to negative power
    if (exp < 0) {
      logger.log_debug(
          "int pow exp<0 (base=%ld exp=%ld)", (long)base, (long)exp);
      return 0;
    }

    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      base *= base;
    }
    return res;
  }

  // Complex power via polar form: exp(y * log(x)).
  template <> complex64_t operator()(complex64_t x, complex64_t y) {
    if (x.real == 0 && x.imag == 0) {
      if (metal::isnan(y.real) || metal::isnan(y.imag)) {
        auto nan = metal::numeric_limits::quiet_NaN();
        return {nan, nan};
      }
      // NOTE(review): this also returns 0 for a zero exponent, whereas
      // numpy/Python give 1 for 0**0; confirm intended.
      return {0.0, 0.0};
    }
    auto x_theta = metal::atan2(x.imag, x.real);
    auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
    auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
    auto phase = y.imag * x_ln_r + y.real * x_theta;
    return {mag * metal::cos(phase), mag * metal::sin(phase)};
  }
};

struct Subtract {
  template T operator()(T x, T y) {
    return x - y;
  }
};

struct LogicalAnd {
  template T operator()(T x, T y) {
    return x && y;
  };
};

struct LogicalOr {
  template T operator()(T x, T y) {
    return x || y;
  };
};

struct BitwiseAnd {
  template T operator()(T x, T y) {
    return x & y;
  };
};

struct BitwiseOr {
  template T operator()(T x, T y) {
    return x | y;
  };
};

struct BitwiseXor {
  template T operator()(T x, T y) {
    return x ^ y;
  };
};

struct LeftShift {
  template T operator()(T x, T y) {
    return x << y;
  };
};

struct RightShift {
  template T operator()(T x, T y) {
    return x >> y;
  };
};

// atan2(y, x) using the precise-math variant.
struct ArcTan2 {
  template T operator()(T y, T x) {
    return metal::precise::atan2(y, x);
  }
};

// Two-output op: {quotient, remainder} = {FloorDivide, Remainder}.
struct DivMod {
  template metal::array operator()(T x, T y) {
    return {FloorDivide{}(x, y), Remainder{}(x, y)};
  };
};

================================================
FILE: mlx/backend/metal/kernels/binary_two.h
================================================
// Copyright © 2024 Apple Inc.

// Two-output kernels: Op returns a two-element value whose [0] element is
// written to `c` and [1] element to `d`. Indexing and tail handling are the
// same as the single-output kernels in binary.h.

// Scalar-scalar.
template [[kernel]] void binary_ss(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto out = Op()(a[0], b[0]);
  c[index] = out[0];
  d[index] = out[1];
}

// Scalar-vector; each thread handles N consecutive elements, stopping at
// `size` in the tail chunk.
template ::n> [[kernel]] void binary_sv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  index *= N;
  if (N > 1 && index + N > size) {
    for (int i = 0; index + i < size; ++i) {
      auto out = Op()(a[0], b[index + i]);
      c[index + i] = out[0];
      d[index + i] = out[1];
    }
  } else {
    for (int i = 0; i < N; ++i) {
      auto out = Op()(a[0], b[index + i]);
      c[index + i] = out[0];
      d[index + i] = out[1];
    }
  }
}

// Vector-scalar.
template ::n> [[kernel]] void binary_vs(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  index *= N;
  if (N > 1 && index + N > size) {
    for (int i = 0; index + i < size; ++i) {
      auto out = Op()(a[index + i], b[0]);
      c[index + i] = out[0];
      d[index + i] = out[1];
    }
  } else {
    for (int i = 0; i < N; ++i) {
      auto out = Op()(a[index + i], b[0]);
      c[index + i] = out[0];
      d[index + i] = out[1];
    }
  }
}

// Vector-vector.
template ::n> [[kernel]] void binary_vv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  index *= N;
  if (N > 1 && index + N > size) {
    for (int i = 0; index + i < size; ++i) {
      auto out = Op()(a[index + i], b[index + i]);
      c[index + i] = out[0];
      d[index + i] = out[1];
    }
  } else {
    for (int i = 0; i < N; ++i) {
      auto out = Op()(a[index + i], b[index + i]);
      c[index + i] = out[0];
      d[index + i] = out[1];
    }
  }
}

// 2-D-grid scalar-vector with an int64_t linear offset.
template ::n> [[kernel]] void binary_sv2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    for (int i = 0; offset + i < size; ++i) {
      auto out = Op()(a[0], b[offset + i]);
      c[offset + i] = out[0];
      d[offset + i] = out[1];
    }
  } else {
    for (int i = 0; i < N; ++i) {
      auto out = Op()(a[0], b[offset + i]);
      c[offset + i] = out[0];
      d[offset + i] = out[1];
    }
  }
}

// 2-D-grid vector-scalar with an int64_t linear offset.
template ::n> [[kernel]] void binary_vs2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    for (int i = 0; offset + i < size; ++i) {
      auto out = Op()(a[offset + i], b[0]);
      c[offset + i] = out[0];
      d[offset + i] = out[1];
    }
  } else {
    for (int i = 0; i < N; ++i) {
      auto out = Op()(a[offset + i], b[0]);
      c[offset + i] = out[0];
      d[offset + i] = out[1];
    }
  }
}

// 2-D-grid vector-vector with an int64_t linear offset.
template ::n> [[kernel]] void binary_vv2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    for (int i = 0; offset + i < size; ++i) {
      auto out = Op()(a[offset + i], b[offset + i]);
      c[offset + i] = out[0];
      d[offset + i] = out[1];
    }
  } else {
    for (int i = 0; i < N; ++i) {
      auto out = Op()(a[offset + i], b[offset + i]);
      c[offset + i] = out[0];
      d[offset + i] = out[1];
    }
  }
}

// General strided, 1-D.
template [[kernel]] void binary_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int64_t& a_stride,
    constant const int64_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_stride);
  auto b_idx = elem_to_loc_1(index, b_stride);
  auto out = Op()(a[a_idx], b[b_idx]);
  c[index] = out[0];
  d[index] = out[1];
}

// General strided, 2-D; outputs written contiguously.
template [[kernel]] void binary_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int64_t a_strides[2],
    constant const int64_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
  auto out = Op()(a[a_idx], b[b_idx]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}

// General strided, 3-D; outputs written contiguously.
template [[kernel]] void binary_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int64_t a_strides[3],
    constant const int64_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
  auto out = Op()(a[a_idx], b[b_idx]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}

// General N-d: up to N elements per thread along the last axis.
template <
    typename T,
    typename U,
    typename Op,
    int N = 1,
    typename IdxT = int64_t>
[[kernel]] void binary_g(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(
      {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
  auto xshape = shape[ndim - 1];
  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
  IdxT a_xstride = a_strides[ndim - 1];
  IdxT b_xstride = b_strides[ndim - 1];
  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
    auto out = Op()(a[idx.x], b[idx.y]);
    c[out_idx] = out[0];
    d[out_idx++] = out[1];
    idx.x += a_xstride;
    idx.y += b_xstride;
  }
}

================================================
FILE: mlx/backend/metal/kernels/binary_two.metal
================================================
// Copyright © 2024 Apple Inc.
#include
#include

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"

// Two-output kernels with the default elements-per-thread N ("n" suffix).
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
  instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
  instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
  instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)

// Base two-output kernel set: ss; sv/vs/vv with N = 1; 2-D-grid sv2/vs2/vv2;
// general gn2 (N = 2, int indexing) and gn4large (N = 4, default index type);
// fixed-rank g1/g2/g3 with int indexing and with the default index type.
#define instantiate_binary_base(op, tname, itype, otype) \
  instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
  instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
  instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
  instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
  instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
  instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
  instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
  instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
  instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
  instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
  instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
  instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)

// Base set plus the work-per-thread variants.
#define instantiate_binary_all(op, tname, itype, otype) \
  instantiate_binary_base(op, tname, itype, otype) \
  instantiate_binary_work_per_thread(op, tname, itype, otype)

#define instantiate_binary_float(op) \
  instantiate_binary_all(op, float16, half, half) \
  instantiate_binary_all(op, float32, float, float) \
  instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)

// All supported types; 64-bit integers and complex64 get only the base set.
#define instantiate_binary_types(op) \
  instantiate_binary_all(op, bool_, bool, bool) \
  instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
  instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
  instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
  instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
  instantiate_binary_all(op, int8, int8_t, int8_t) \
  instantiate_binary_all(op, int16, int16_t, int16_t) \
  instantiate_binary_all(op, int32, int32_t, int32_t) \
  instantiate_binary_base(op, int64, int64_t, int64_t) \
  instantiate_binary_base(op, complex64, complex64_t, complex64_t) \
  instantiate_binary_float(op)

// DivMod is the only two-output op instantiated in this file.
instantiate_binary_types(DivMod) // clang-format on

================================================
FILE: mlx/backend/metal/kernels/cexpf.h
================================================
// Copyright © 2025 Apple Inc.
// Copyright © 2008-2013 NVIDIA Corporation
// Copyright © 2013 Filipe RNC Maia
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h

// TODO: We should use thrust::exp but the thrust header in old CUDA versions
// can not be used in JIT.
#pragma once

#include

// Union used to read and write a float's raw bits as a uint32_t.
using ieee_float_shape_type = union {
  float value;
  uint32_t word;
};

// Stores the bit pattern of `d` in `i`.
inline void get_float_word(thread uint32_t& i, float d) {
  ieee_float_shape_type gf_u;
  gf_u.value = (d);
  (i) = gf_u.word;
}

// Signed overload: stores the bit pattern of `d` in `i`.
inline void get_float_word(thread int32_t& i, float d) {
  ieee_float_shape_type gf_u;
  gf_u.value = (d);
  (i) = gf_u.word;
}

// Sets `d` to the float whose bit pattern is `i`.
inline void set_float_word(thread float& d, uint32_t i) {
  ieee_float_shape_type sf_u;
  sf_u.word = (i);
  (d) = sf_u.value;
}

// Computes exp(x) in scaled form. Returns r (with its biased exponent field
// set to 0x7f + 127) and writes *expt such that exp(x) == r * 2^(*expt).
// exp is evaluated at x - k*ln(2) (kln2), and k is added back into *expt,
// so the intermediate does not overflow for the x values cexpf routes here.
inline float frexp_expf(float x, thread int* expt) {
  const uint32_t k = 235;
  const float kln2 = 162.88958740F;

  float exp_x;
  uint32_t hx;

  exp_x = metal::exp(x - kln2);
  get_float_word(hx, exp_x);
  *expt = (hx >> 23) - (0x7f + 127) + k;
  set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
  return exp_x;
}

// Returns exp(z) * 2^expt. The total power of two (expt plus the exponent from
// frexp_expf) is split into two factors, scale1 and scale2, that are applied
// one after the other.
inline complex64_t ldexp_cexpf(complex64_t z, int expt) {
  float x, y, exp_x, scale1, scale2;
  int ex_expt, half_expt;

  x = z.real;
  y = z.imag;
  exp_x = frexp_expf(x, &ex_expt);
  expt += ex_expt;

  half_expt = expt / 2;
  set_float_word(scale1, (0x7f + half_expt) << 23);
  half_expt = expt - half_expt;
  set_float_word(scale2, (0x7f + half_expt) << 23);

  return complex64_t{
      metal::cos(y) * exp_x * scale1 * scale2,
      metal::sin(y) * exp_x * scale1 * scale2};
}

// Complex exponential of a complex64_t. Handles zero components, infinities
// and NaNs explicitly, and uses ldexp_cexpf to avoid overflow when the real
// part's bit pattern lies in [exp_ovfl, cexp_ovfl].
inline complex64_t cexpf(const thread complex64_t& z) {
  float x, y, exp_x;
  uint32_t hx, hy;

  // Bit patterns of roughly 88.7 and 192 (see the comment below).
  const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;

  x = z.real;
  y = z.imag;

  get_float_word(hy, y);
  hy &= 0x7fffffff;

  /* cexp(x + I 0) = exp(x) + I 0 */
  if (hy == 0) {
    return complex64_t{metal::exp(x), y};
  }
  get_float_word(hx, x);
  /* cexp(0 + I y) = cos(y) + I sin(y) */
  if ((hx & 0x7fffffff) == 0) {
    return complex64_t{metal::cos(y), metal::sin(y)};
  }
  if (hy >= 0x7f800000) {
    if ((hx & 0x7fffffff) != 0x7f800000) {
      /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
      return complex64_t{y - y, y - y};
    } else if (hx & 0x80000000) {
      /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
      return complex64_t{0.0, 0.0};
    } else {
      /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
      return complex64_t{x, y - y};
    }
  }

  // hx still includes the sign bit, so negative x never takes this branch.
  if (hx >= exp_ovfl && hx <= cexp_ovfl) {
    /*
     * x is between 88.7 and 192, so we must scale to avoid
     * overflow in expf(x).
     */
    return ldexp_cexpf(z, 0);
  } else {
    /*
     * Cases covered here:
     *  -  x < exp_ovfl and exp(x) won't overflow (common case)
     *  -  x > cexp_ovfl, so exp(x) * s overflows for all s > 0
     *  -  x = +-Inf (generated by exp())
     *  -  x = NaN (spurious inexact exception from y)
     */
    exp_x = metal::exp(x);
    return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)};
  }
}

================================================
FILE: mlx/backend/metal/kernels/complex.h
================================================
// Copyright © 2023 Apple Inc.

#pragma once

#include

using namespace metal;

struct complex64_t;

// Traits gating the implicit conversions in complex64_t. The type arguments
// are not visible in this copy; presumably "T is not complex64_t and converts
// to/from float".
template static constexpr constant bool can_convert_to_complex64 =
    !is_same_v && is_convertible_v;

template static constexpr constant bool can_convert_from_complex64 =
    !is_same_v &&
    (is_convertible_v || is_convertible_v);

// Single-precision complex number stored as two floats.
struct complex64_t {
  float real;
  float imag;

  // Constructors
  constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
  constexpr complex64_t() : real(0), imag(0) {};
  constexpr complex64_t() threadgroup : real(0), imag(0) {};

  // Conversions to complex64_t (one per address space): the value becomes the
  // real part and the imaginary part is 0.
  template <
      typename T,
      typename = typename enable_if>::type>
  constexpr complex64_t(T x) thread : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if>::type>
  constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if>::type>
  constexpr complex64_t(T x) device : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if>::type>
  constexpr complex64_t(T x) constant : real(x), imag(0) {}

  // Conversions from complex64_t (one per address space): only the real part
  // is kept.
  template <
      typename T,
      typename = typename enable_if>::type>
  constexpr operator T() const thread {
    return static_cast(real);
  }

  template <
      typename T,
      typename = typename enable_if>::type>
  constexpr operator T() const threadgroup {
    return static_cast(real);
  }

  template <
      typename T,
      typename = typename enable_if>::type>
  constexpr operator T() const device {
    return static_cast(real);
  }

  template <
      typename T,
      typename = typename enable_if>::type>
  constexpr operator T() const constant {
    return static_cast(real);
  }
};

constexpr complex64_t operator-(complex64_t x) {
  return {-x.real, -x.imag};
}

// Ordering is lexicographic: real parts first, then imaginary parts.
constexpr bool operator>=(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
}

constexpr bool operator>(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
}

constexpr bool operator<=(complex64_t a, complex64_t b) {
  return operator>=(b, a);
}

constexpr bool operator<(complex64_t a, complex64_t b) {
  return operator>(b, a);
}

constexpr bool operator==(complex64_t a, complex64_t b) {
  return a.real == b.real && a.imag == b.imag;
}

constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return {a.real + b.real, a.imag + b.imag};
}

// In-place addition, one overload per address space of the target.
constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
  a.real += b.real;
  a.imag += b.imag;
  return a;
}

constexpr threadgroup complex64_t& operator+=(
    threadgroup complex64_t& a,
    complex64_t b) {
  a.real += b.real;
  a.imag += b.imag;
  return a;
}

constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
  a.real += b.real;
  a.imag += b.imag;
  return a;
}

constexpr complex64_t operator+(float a, complex64_t b) {
  return {a + b.real, b.imag};
}
constexpr complex64_t operator+(complex64_t a, float b) {
  return {a.real + b, a.imag};
}

constexpr complex64_t operator-(complex64_t a, complex64_t b) {
  return {a.real - b.real, a.imag - b.imag};
}
constexpr complex64_t operator-(float a, complex64_t b) {
  return {a - b.real, -b.imag};
}
constexpr complex64_t operator-(complex64_t a, float b) {
  return {a.real - b, a.imag};
}

constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}

// a / b = a * conj(b) / |b|^2. NOTE(review): |b|^2 is computed without
// scaling, so it can overflow or underflow for very large or very small |b|.
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
  auto denom = b.real * b.real + b.imag * b.imag;
  auto x = a.real * b.real + a.imag * b.imag;
  auto y = a.imag * b.real - a.real * b.imag;
  return {x / denom, y / denom};
}

constexpr complex64_t operator/(float a, complex64_t b) {
  auto denom = b.real * b.real + b.imag * b.imag;
  auto x = a * b.real;
  auto y = -a * b.imag;
  return {x / denom, y / denom};
}

// Component-wise remainder. Each component is a - b * cast(a / b) (cast type
// not visible in this copy), then shifted by b when a nonzero result's sign
// differs from b's.
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  auto real = a.real - (b.real * static_cast(a.real / b.real));
  auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag));
  if (real != 0 && (real < 0 != b.real < 0)) {
    real += b.real;
  }
  if (imag != 0 && (imag < 0 != b.imag < 0)) {
    imag += b.imag;
  }
  return {real, imag};
}

================================================
FILE: mlx/backend/metal/kernels/conv.metal
================================================
// Copyright © 2023-2024 Apple Inc.

#include
#include
#include

#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/utils.h"

#define MLX_MTL_CONST static constant constexpr const

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
/// Naive unfold with dilation
///////////////////////////////////////////////////////////////////////////////

// Unfolds (im2col-style) the input of an N-d convolution. Each unfolded row
// gid.z corresponds to one (batch, output position) pair and holds
// filter_size = C * prod(wS) values, laid out tap-major: tap gid.y occupies C
// consecutive entries, and this thread writes channel gid.x of it. Taps
// falling in padding or between input-dilation samples are written as 0.
template [[kernel]] void naive_unfold_Nd(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant MLXConvParams* params [[buffer(2)]],
    uint3 gid [[thread_position_in_grid]]) {
  int filter_size = params->C;
  for (short i = 0; i < N; i++)
    filter_size *= params->wS[i];

  int out_pixels = 1;
  for (short i = 0; i < N; i++)
    out_pixels *= params->oS[i];

  // Set out
  out += (size_t)gid.z * filter_size + (size_t)gid.y * (params->C);

  // Coordinates in input
  int is[N] = {0};

  // gid.z: N oS (Batch and row in unfolded output)
  // gid.y: wS (Filter location to unfold input)
  // gid.x: C (channel)

  int n = (gid.z) / out_pixels;
  int oS = (gid.z) % out_pixels;
  int wS = gid.y;

  bool valid = n < params->N;

  // Unroll dimensions
  for (int i = N - 1; i >= 0; --i) {
    int os_ = (oS % params->oS[i]);
    int ws_ = (wS % params->wS[i]);

    ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;

    // is_ is a position in the input as dilated by idil (extent is_max); it
    // maps to a real input sample only when it lies in range and is a
    // multiple of idil.
    int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
    int is_max = 1 + params->idil[i] * (params->iS[i] - 1);

    valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);

    is[i] = is_ / params->idil[i];

    oS /= params->oS[i];
    wS /= params->wS[i];
  }

  if (valid) {
    size_t in_offset = n * params->in_strides[0];

    for (int i = 0; i < N; ++i) {
      in_offset += is[i] * params->in_strides[i + 1];
    }

    out[gid.x] = in[in_offset + gid.x];
  } else {
    out[gid.x] = T(0);
  }
}

// This kernel unfolds the input array of size (N, *spatial_dims, C)
// into an array of size (N x *spatial_dims, C x *kernel_dims).
// Within a row, channel gid.x starts at gid.x * (filter_size / C), and taps
// are laid out row-major over the kernel dims. The output tap position uses
// the unflipped tap index; `flip` only changes which input element is read.
template [[kernel]] void naive_unfold_transpose_Nd(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant MLXConvParams* params [[buffer(2)]],
    uint3 gid [[thread_position_in_grid]]) {
  int filter_size = params->C;
  for (short i = 0; i < N; i++)
    filter_size *= params->wS[i];

  int out_pixels = 1;
  for (short i = 0; i < N; i++)
    out_pixels *= params->oS[i];

  // Set out
  out +=
      (size_t)gid.z * filter_size + (size_t)gid.x * (filter_size / params->C);

  // Coordinates in input
  int is[N] = {0};

  // gid.z: N oS (Batch and row in unfolded output)
  // gid.y: wS (Filter location to unfold input)
  // gid.x: C (channel)

  int n = (gid.z) / out_pixels;
  int oS = (gid.z) % out_pixels;
  int wS = gid.y;

  bool valid = n < params->N;

  // Unroll dimensions
  int kernel_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    int os_ = (oS % params->oS[i]);
    int ws_ = (wS % params->wS[i]);
    out += ws_ * kernel_stride;

    ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;

    int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
    int is_max = 1 + params->idil[i] * (params->iS[i] - 1);

    valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);

    is[i] = is_ / params->idil[i];

    oS /= params->oS[i];
    wS /= params->wS[i];

    kernel_stride *= params->wS[i];
  }

  if (valid) {
    size_t in_offset = n * params->in_strides[0];

    for (int i = 0; i < N; ++i) {
      in_offset += is[i] * params->in_strides[i + 1];
    }

    out[0] = in[in_offset + gid.x];
  } else {
    out[0] = T(0);
  }
}

// Instantiates both unfold kernels for one type and rank under host names
// naive_unfold_nd_<name>_<n> and naive_unfold_transpose_nd_<name>_<n>.
#define instantiate_naive_unfold_nd(name, itype, n) \
  template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \
  naive_unfold_Nd( \
      const device itype* in [[buffer(0)]], \
      device itype* out [[buffer(1)]], \
      const constant MLXConvParams* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]); \
  template \
  [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
  naive_unfold_transpose_Nd( \
      const device itype* in [[buffer(0)]], \
      device itype* out [[buffer(1)]], \
      const constant MLXConvParams* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]);

// Ranks 1, 2 and 3.
#define instantiate_naive_unfold_nd_dims(name, itype) \
  instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \
      name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3)

instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////

// Pipeline function constants; the indices must match the values the host
// supplies when building the pipeline.
constant int ker_h [[function_constant(00)]];
constant int ker_w [[function_constant(01)]];
constant int str_h [[function_constant(10)]];
constant int str_w [[function_constant(11)]];
constant int tgp_h [[function_constant(100)]];
constant int tgp_w [[function_constant(101)]];
constant bool do_flip
[[function_constant(200)]]; constant int span_h = tgp_h * str_h + ker_h - 1; constant int span_w = tgp_w * str_w + ker_w - 1; constant int span_hw = span_h * span_w; template [[kernel]] void depthwise_conv_2d( const device T* in [[buffer(0)]], const device T* wt [[buffer(1)]], device T* out [[buffer(2)]], const constant MLXConvParams<2>& params [[buffer(3)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 gid [[thread_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int tc = 8; constexpr int tw = 8; constexpr int th = 4; constexpr int c_per_thr = 8; constexpr int TGH = th * 2 + 6; constexpr int TGW = tw * 2 + 6; constexpr int TGC = tc; threadgroup T ins[TGH * TGW * TGC]; const int n_tgblocks_h = params.oS[0] / th; const int n = tid.z / n_tgblocks_h; const int tghid = tid.z % n_tgblocks_h; const int oh = tghid * th + lid.z; const int ow = gid.y; const int c = gid.x; in += n * params.in_strides[0]; // Load in { constexpr int n_threads = th * tw * tc; const int tg_oh = (tghid * th) * str_h - params.pad[0]; const int tg_ow = (tid.y * tw) * str_w - params.pad[1]; const int tg_c = tid.x * tc; const int thread_idx = simd_gid * 32 + simd_lid; constexpr int thr_per_hw = tc / c_per_thr; constexpr int hw_per_group = n_threads / thr_per_hw; const int thr_c = thread_idx % thr_per_hw; const int thr_hw = thread_idx / thr_per_hw; for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) { const int h = hw / span_w; const int w = hw % span_w; const int ih = tg_oh + h; const int iw = tg_ow + w; const int in_s_offset = h * span_w * TGC + w * TGC; if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) { const auto in_load = in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c; MLX_MTL_PRAGMA_UNROLL for (int cc = 0; cc < c_per_thr; ++cc) { ins[in_s_offset + c_per_thr * thr_c + cc] = in_load[c_per_thr * thr_c + cc]; } } else { 
MLX_MTL_PRAGMA_UNROLL for (int cc = 0; cc < c_per_thr; ++cc) { ins[in_s_offset + c_per_thr * thr_c + cc] = T(0); } } } } threadgroup_barrier(mem_flags::mem_threadgroup); wt += c * params.wt_strides[0]; const auto ins_ptr = &ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x]; float o = 0.; for (int h = 0; h < ker_h; ++h) { for (int w = 0; w < ker_w; ++w) { int wt_h = h; int wt_w = w; if (do_flip) { wt_h = ker_h - h - 1; wt_w = ker_w - w - 1; } auto inv = ins_ptr[h * span_w * TGC + w * TGC]; auto wtv = wt[wt_h * ker_w + wt_w]; o += inv * wtv; } } threadgroup_barrier(mem_flags::mem_none); out += n * params.out_strides[0] + oh * params.out_strides[1] + ow * params.out_strides[2]; out[c] = static_cast(o); } #define instantiate_depthconv2d(iname, itype) \ instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype) instantiate_depthconv2d(float32, float); instantiate_depthconv2d(float16, half); instantiate_depthconv2d(bfloat16, bfloat16_t); template [[kernel]] void depthwise_conv_1d( const device T* in [[buffer(0)]], const device T* w [[buffer(1)]], device T* out [[buffer(2)]], constant const IdxT strides[3], constant const int& kernel_size, uint3 tid [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { out += (tid.z * static_cast(grid_dim.y) + tid.y) * grid_dim.x + tid.x; in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2]; w += tid.x * kernel_size; float acc = 0.0; for (int i = 0; i < kernel_size; ++i) { acc += static_cast(in[0]) * w[i]; in += strides[1]; } *out = static_cast(acc); } #define instantiate_depthconv1d(iname, itype) \ instantiate_kernel( \ "depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \ instantiate_kernel( \ "depthwise_conv_1d_" #iname "_large", \ depthwise_conv_1d, \ itype, \ int64_t) instantiate_depthconv1d(float32, float); instantiate_depthconv1d(float16, half); instantiate_depthconv1d(bfloat16, bfloat16_t); 
/////////////////////////////////////////////////////////////////////////////// /// Winograd kernels /////////////////////////////////////////////////////////////////////////////// template struct WinogradTransforms {}; /* F(6x6, 3x3) Winograd matrices (output tile 6, filter 3, input tile 8), each stored as an 8 x 8 array so it can be loaded into one simdgroup_matrix; rows written with fewer than 8 values are zero-filled. in_transform is B, out_transform is A (8 x 6 used) and wt_transform is G (8 x 3 used). */ template <> struct WinogradTransforms<6, 3, 8> { MLX_MTL_CONST int OUT_TILE_SIZE = 6; MLX_MTL_CONST int FILTER_SIZE = 3; MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1; MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8; MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, {0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f}, {-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f}, {0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f}, {5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f}, {0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f}, {-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f}, {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, }; MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, {1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}, {1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f}, {1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f}, {1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f}, {1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f}, {1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f}, {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, }; MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { {1.00, 0.00, 0.00}, {-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00}, {-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00}, {1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0}, {1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0}, {32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0}, {32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0}, {0.00, 0.00, 1.00}, }; }; constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8]; constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8]; constant constexpr
const float WinogradTransforms<6, 3, 8>::out_transform[8][8]; /* Winograd weight transform. Simdgroup simd_group_id of threadgroup tid handles output filter ko = BO * tid + simd_group_id: it reads the R x R x C filter (channel innermost) and, for every channel, writes G * g * Gt to wt_out laid out as (8 x 8 x C x O). C is walked in chunks of BC staged in threadgroup memory with no tail handling, so C is assumed to be a multiple of BC. */ template [[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform( const device T* wt_in [[buffer(0)]], device T* wt_out [[buffer(1)]], const constant int& C [[buffer(2)]], const constant int& O [[buffer(3)]], uint tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { using WGT = WinogradTransforms; // Get lane position in simdgroup /* this lane owns row sm and columns sn, sn + 1 of each 8 x 8 simdgroup_matrix, i.e. thread_elements()[0] and [1] */ const short qid = simd_lane_id / 4; const short sm = (qid & 4) + (simd_lane_id / 2) % 4; const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; // Initialize G matrix simdgroup_matrix G; G.thread_elements()[0] = WGT::wt_transform[sm][sn]; G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1]; // Initialize Gt matrix simdgroup_matrix Gt; Gt.thread_elements()[0] = WGT::wt_transform[sn][sm]; Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm]; // Move to the correct output filter size_t ko = BO * tid + simd_group_id; wt_in += ko * R * R * C; // wt_out is stored transposed (A x A x C x O) short ohw_0 = sm * 8 + sn; short ohw_1 = sm * 8 + sn + 1; device T* wt_out_0 = wt_out + ohw_0 * C * O + ko; device T* wt_out_1 = wt_out + ohw_1 * C * O + ko; // Prepare shared memory threadgroup T Ws[BO][R][R][BC]; // Loop over C for (int bc = 0; bc < C; bc += BC) { threadgroup_barrier(mem_flags::mem_threadgroup); // Read into shared memory for (int kh = 0; kh < R; ++kh) { for (int kw = 0; kw < R; ++kw) { for (int kc = simd_lane_id; kc < BC; kc += 32) { Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc]; } } } threadgroup_barrier(mem_flags::mem_threadgroup); // Do transform and store the result for (int c = 0; c < BC; ++c) { simdgroup_matrix g; g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0); g.thread_elements()[1] = sm < R && sn + 1 < R ?
Ws[simd_group_id][sm][sn + 1][c] : T(0); simdgroup_matrix g_out = (G * g) * Gt; wt_out_0[c * O] = static_cast(g_out.thread_elements()[0]); wt_out_1[c * O] = static_cast(g_out.thread_elements()[1]); } wt_in += BC; wt_out_0 += BC * O; wt_out_1 += BC * O; } } #define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \ template [[host_name( \ "winograd_conv_2d_weight_transform_" #name "_bc" #bc)]] [[kernel]] void \ winograd_conv_2d_weight_transform( \ const device itype* wt_in [[buffer(0)]], \ device itype* wt_out [[buffer(1)]], \ const constant int& C [[buffer(2)]], \ const constant int& O [[buffer(3)]], \ uint tid [[threadgroup_position_in_grid]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]]); /* Winograd input transform. Threadgroup (tid.x, tid.y, tid.z) loads the A x A input tile (A = IN_TILE_SIZE) starting at row M * tid.y, column M * tid.x of image tid.z, split across WM x WN simdgroups, and writes Bt * I * B for every channel to inp_out laid out as (A x A x tiles x C). Consecutive tiles advance by M but span A, so they overlap. The input reads are not bounds-checked (NOTE(review): presumably the caller pads the input -- confirm), and C is walked in chunks of BC with no tail handling. */ template [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform( const device T* inp_in [[buffer(0)]], device T* inp_out [[buffer(1)]], const constant MLXConvParams<2>& params [[buffer(2)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 tgp_per_grid [[threadgroups_per_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { (void)lid; using WGT = WinogradTransforms; constexpr int A = WGT::IN_TILE_SIZE; constexpr int N_SIMD_GROUPS = WM * WN; // Get lane position in simdgroup const short qid = simd_lane_id / 4; const short sm = (qid & 4) + (simd_lane_id / 2) % 4; const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; // Initialize B matrix simdgroup_matrix B; B.thread_elements()[0] = WGT::in_transform[sm][sn]; B.thread_elements()[1] = WGT::in_transform[sm][sn + 1]; // Initialize Bt matrix simdgroup_matrix Bt; Bt.thread_elements()[0] = WGT::in_transform[sn][sm]; Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm]; // Resolve input tile constexpr int TH = (A / WM); constexpr int TW = (A / WN); int kh = TH * (simd_group_id / WN); int kw = TW *
(simd_group_id % WN); int bh = M * tid.y + kh; int bw = M * tid.x + kw; // Move to the correct input tile inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] + bw * params.in_strides[2]; // Pre compute strides /* NOTE(review): stride offsets are kept in int, assuming they fit in 32 bits */ int jump_in[TH][TW]; for (int h = 0; h < TH; h++) { for (int w = 0; w < TW; w++) { jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2]; } } // inp_out is stored interleaved (A x A x tiles x C) size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z; size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x; size_t ohw_0 = sm * 8 + sn; size_t ohw_1 = sm * 8 + sn + 1; device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C; device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C; // Prepare shared memory threadgroup T Is[A][A][BC]; // Loop over C for (int bc = 0; bc < params.C; bc += BC) { threadgroup_barrier(mem_flags::mem_threadgroup); // Read into shared memory for (int h = 0; h < TH; h++) { for (int w = 0; w < TW; w++) { const device T* in_ptr = inp_in + jump_in[h][w]; for (int c = simd_lane_id; c < BC; c += 32) { Is[kh + h][kw + w][c] = in_ptr[c]; } } } threadgroup_barrier(mem_flags::mem_threadgroup); // Do transform and store the result for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) { simdgroup_matrix I; I.thread_elements()[0] = Is[sm][sn][c]; I.thread_elements()[1] = Is[sm][sn + 1][c]; simdgroup_matrix I_out = (Bt * I) * B; inp_out_0[c] = static_cast(I_out.thread_elements()[0]); inp_out_1[c] = static_cast(I_out.thread_elements()[1]); } inp_in += BC; inp_out_0 += BC; inp_out_1 += BC; } } #define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \ template [[host_name( \ "winograd_conv_2d_input_transform_" #name "_bc" #bc)]] [[kernel]] void \ winograd_conv_2d_input_transform( \ const device itype* inp_in [[buffer(0)]], \ device itype* inp_out [[buffer(1)]], \ const constant MLXConvParams<2>& params [[buffer(2)]], \
uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]], \ uint3 tgp_per_grid [[threadgroups_per_grid]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]]); /* Winograd output transform. Simdgroups split the BO output channels of each chunk; for each channel the 8 x 8 transformed tile read from out_in (A x A x tiles x O) becomes At * Y * A (computed as Bt * (O_mat * B)), whose top-left M x M block is kept in threadgroup memory. That block is then copied to out_out, skipping positions outside params.oS (jump_in = -1). O is walked in chunks of BO with no tail handling. */ template [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform( const device T* out_in [[buffer(0)]], device T* out_out [[buffer(1)]], const constant MLXConvParams<2>& params [[buffer(2)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 tgp_per_grid [[threadgroups_per_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { (void)lid; using WGT = WinogradTransforms; constexpr int N_SIMD_GROUPS = WM * WN; // Get lane position in simdgroup const short qid = simd_lane_id / 4; const short sm = (qid & 4) + (simd_lane_id / 2) % 4; const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; // Initialize A matrix (out_transform), held in B simdgroup_matrix B; B.thread_elements()[0] = WGT::out_transform[sm][sn]; B.thread_elements()[1] = WGT::out_transform[sm][sn + 1]; // Initialize At matrix, held in Bt simdgroup_matrix Bt; Bt.thread_elements()[0] = WGT::out_transform[sn][sm]; Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm]; // Out_in comes in shape (A x A x tiles x O) // We do transform and then write out to out_out in shape (N, H, W, O) // Resolve output tile constexpr int TH = (M / WM); constexpr int TW = (M / WN); int kh = TH * (simd_group_id / WN); int kw = TW * (simd_group_id % WN); int bh = M * tid.y + kh; int bw = M * tid.x + kw; // Move to the correct output tile out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] + bw * params.out_strides[2]; // Pre compute strides int jump_in[TH][TW]; for (int h = 0; h < TH; h++) { for (int w = 0; w < TW; w++) { bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]); jump_in[h][w] = valid ?
h * params.out_strides[1] + w * params.out_strides[2] : -1; } } // out_in is stored interleaved (A x A x tiles x O) size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z; size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x; size_t ohw_0 = sm * 8 + sn; size_t ohw_1 = sm * 8 + sn + 1; const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O; const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O; // Prepare shared memory threadgroup T Os[M][M][BO]; // Loop over O for (int bo = 0; bo < params.O; bo += BO) { threadgroup_barrier(mem_flags::mem_threadgroup); // Do transform and store the result for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) { simdgroup_matrix O_mat; O_mat.thread_elements()[0] = out_in_0[c]; O_mat.thread_elements()[1] = out_in_1[c]; simdgroup_matrix O_out = (Bt * (O_mat * B)); if ((sm < M) && (sn < M)) { Os[sm][sn][c] = static_cast(O_out.thread_elements()[0]); } if ((sm < M) && ((sn + 1) < M)) { Os[sm][sn + 1][c] = static_cast(O_out.thread_elements()[1]); } } threadgroup_barrier(mem_flags::mem_threadgroup); // Read out from shared memory for (int h = 0; h < TH; h++) { for (int w = 0; w < TW; w++) { if (jump_in[h][w] >= 0) { device T* out_ptr = out_out + jump_in[h][w]; for (int c = simd_lane_id; c < BO; c += 32) { out_ptr[c] = Os[kh + h][kw + w][c]; } } } } out_out += BO; out_in_0 += BO; out_in_1 += BO; } } #define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \ template [[host_name( \ "winograd_conv_2d_output_transform_" #name "_bo" #bo)]] [[kernel]] void \ winograd_conv_2d_output_transform( \ const device itype* out_in [[buffer(0)]], \ device itype* out_out [[buffer(1)]], \ const constant MLXConvParams<2>& params [[buffer(2)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]], \ uint3 tgp_per_grid [[threadgroups_per_grid]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]]); // clang-format off #define instantiate_winograd_conv_2d(name, itype) \ instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \ instantiate_winograd_conv_2d_input_transform(name, itype, 32) \ instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on // clang-format off instantiate_winograd_conv_2d(float32, float); instantiate_winograd_conv_2d(bfloat16, bfloat16_t); instantiate_winograd_conv_2d(float16, half); // clang-format on ================================================ FILE: mlx/backend/metal/kernels/copy.h ================================================ // Copyright © 2024 Apple Inc. /* Fills dst with src[0] converted to U. Each thread writes N consecutive elements starting at index * N; the thread that would run past size writes only the remaining elements. */ template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; if (N > 1 && index + N > size) { for (int i = 0; index + i < size; ++i) { dst[index + i] = static_cast(src[0]); } } else { for (int i = 0; i < N; ++i) { dst[index + i] = static_cast(src[0]); } } } /* Contiguous element-wise copy with conversion to U, N elements per thread, clipped at size in the same way as copy_s. */ template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; if (N > 1 && index + N > size) { for (int i = 0; index + i < size; ++i) { dst[index + i] = static_cast(src[index + i]); } } else { for (int i = 0; i < N; ++i) { dst[index + i] = static_cast(src[index + i]); } } } /* copy_s2 / copy_v2: the copy_s / copy_v logic on a 2-D grid with a 64-bit size; a thread's first element is N * (index.x + grid_dim.x * index.y), computed in int64_t. */ template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); if (N > 1 && offset + N > size) { for (int i = 0; offset + i < size; ++i) { dst[offset + i] = static_cast(src[0]); } } else { for (int i = 0; i < N; ++i) { dst[offset + i] = static_cast(src[0]); } } } template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); if (N > 1 && offset + N > size) { for (int i = 0; offset + i < size; ++i) { dst[offset + i] = static_cast(src[offset + i]); } } else { for (int i = 0; i < N; ++i) { dst[offset + i] = static_cast(src[offset + i]); } } } /* Strided source, contiguous destination. copy_g_nd1/2/3 launch one thread per element of a 1-, 2- or 3-D source and compute the source offset from its strides; the destination index is the thread's linear position in the grid. IdxT is the index type. */ template [[kernel]] void copy_g_nd1( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t& src_stride [[buffer(3)]], uint index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_1(index, src_stride); dst[index] = static_cast(src[src_idx]); } template [[kernel]] void copy_g_nd2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { auto src_idx = elem_to_loc_2(index, src_strides); IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y; dst[dst_idx] = static_cast(src[src_idx]); } template [[kernel]] void copy_g_nd3( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { auto src_idx = elem_to_loc_3(index, src_strides); IdxT dst_idx = index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z); dst[dst_idx] = static_cast(src[src_idx]); } /* Arbitrary-rank strided source (ndim dimensions). With N > 1, each thread copies up to N consecutive elements along the innermost source dimension, stopping at that dimension's extent. */ template [[kernel]] void copy_g( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int* src_shape [[buffer(2)]], constant const int64_t* src_strides [[buffer(3)]], constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { auto src_idx = elem_to_loc( {N * index.x, index.y, index.z}, src_shape, src_strides, ndim); if (N == 1) { IdxT dst_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); dst[dst_idx]
= static_cast(src[src_idx]); return; } auto xshape = src_shape[ndim - 1]; IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); auto src_xstride = src_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { dst[dst_idx + i] = static_cast(src[src_idx]); src_idx += src_xstride; } } /* Strided source to strided destination: the source and destination offsets are computed from the same thread position, each with its own strides. */ template [[kernel]] void copy_gg_nd1( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& dst_stride [[buffer(4)]], uint index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_1(index, src_stride); auto dst_idx = elem_to_loc_1(index, dst_stride); dst[dst_idx] = static_cast(src[src_idx]); } template [[kernel]] void copy_gg_nd2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], uint2 index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_2(index, src_strides); auto dst_idx = elem_to_loc_2(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } template [[kernel]] void copy_gg_nd3( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], uint3 index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_3(index, src_strides); auto dst_idx = elem_to_loc_3(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } /* Arbitrary-rank version; elem_to_loc_2_nd yields the source offset in idx.x and the destination offset in idx.y. */ template [[kernel]] void copy_gg( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int* src_shape [[buffer(2)]], constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]]) { auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, src_shape, src_strides, dst_strides, ndim); if (N == 1) { dst[idx.y] = static_cast(src[idx.x]); return; }
IdxT src_xstride = src_strides[ndim - 1]; IdxT dst_xstride = dst_strides[ndim - 1]; auto xshape = src_shape[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { dst[idx.y] = static_cast(src[idx.x]); idx.x += src_xstride; idx.y += dst_xstride; } } /* Like the copy_gg kernels, but with extra source and destination element offsets (buffers 6 and 7) added to every access. These assign without a static_cast; the instantiations use the same type for T and U. */ template [[kernel]] void copy_gg_dynamic_nd1( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& dst_stride [[buffer(4)]], constant const int64_t& src_offset [[buffer(6)]], constant const int64_t& dst_offset [[buffer(7)]], uint index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_1(index, src_stride); auto dst_idx = elem_to_loc_1(index, dst_stride); dst[dst_idx + dst_offset] = src[src_idx + src_offset]; } template [[kernel]] void copy_gg_dynamic_nd2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], constant const int64_t& src_offset [[buffer(6)]], constant const int64_t& dst_offset [[buffer(7)]], uint2 index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_2(index, src_strides); auto dst_idx = elem_to_loc_2(index, dst_strides); dst[dst_idx + dst_offset] = src[src_idx + src_offset]; } template [[kernel]] void copy_gg_dynamic_nd3( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], constant const int64_t& src_offset [[buffer(6)]], constant const int64_t& dst_offset [[buffer(7)]], uint3 index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_3(index, src_strides); auto dst_idx = elem_to_loc_3(index, dst_strides); dst[dst_idx + dst_offset] = src[src_idx + src_offset]; } template [[kernel]] void copy_gg_dynamic( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int* src_shape [[buffer(2)]], constant const int64_t* src_strides
[[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], constant const int& ndim [[buffer(5)]], constant const int64_t& src_offset [[buffer(6)]], constant const int64_t& dst_offset [[buffer(7)]], uint3 index [[thread_position_in_grid]]) { src += src_offset; dst += dst_offset; auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, src_shape, src_strides, dst_strides, ndim); if (N == 1) { dst[idx.y] = src[idx.x]; return; } IdxT src_xstride = src_strides[ndim - 1]; IdxT dst_xstride = dst_strides[ndim - 1]; auto xshape = src_shape[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { dst[idx.y] = src[idx.x]; idx.x += src_xstride; idx.y += dst_xstride; } } ================================================ FILE: mlx/backend/metal/kernels/copy.metal ================================================ // Copyright © 2024 Apple Inc. // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/copy.h" /* Kernel host names are a prefix followed by tname. The sn_/vn_ kernels use the kernel's default elements per thread; the other base kernels are N = 1 contiguous (s_/v_), 2-D grid (s2_/v2_), and general-stride kernels (g1/g2/g3 for 1-3 dimensions, gn2/gn4large for any rank with 2 or 4 elements per thread), using int indexing unless the name contains large. */ #define instantiate_copy_work_per_thread(tname, itype, otype) \ instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \ instantiate_kernel("vn_copy" #tname, copy_v, itype, otype) #define instantiate_copy_base(tname, itype, otype) \ instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \ instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \ instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \ instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \ instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \ instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \ instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \ instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \ instantiate_kernel("gn4large_copy" #tname,
copy_g, itype, otype, 4) #define instantiate_copy_all(tname, itype, otype) \ instantiate_copy_base(tname, itype, otype) \ instantiate_copy_work_per_thread(tname, itype, otype) /* Strided-to-strided (gg) and dynamic-offset kernels are instantiated only for same-type copies. */ #define instantiate_copy_same(tname, type) \ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \ instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int) \ instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int) \ instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \ instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \ instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \ instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4) \ instantiate_kernel("gg1_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type, int) \ instantiate_kernel("gg2_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type, int) \ instantiate_kernel("gg3_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type, int) \ instantiate_kernel("ggn2_dynamic_copy" #tname, copy_gg_dynamic, type, type, 2, int) \ instantiate_kernel("gg1large_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type) \ instantiate_kernel("gg2large_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type) \ instantiate_kernel("gg3large_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type) \ instantiate_kernel("ggn4large_dynamic_copy" #tname, copy_gg_dynamic, type, type, 4) /* All copy kernels for one input type: same-type strided copies plus conversions to every output type. 64-bit integer and complex64 outputs get only the base set (no sn_/vn_ work-per-thread kernels). */ #define instantiate_copy_itype(itname, itype) \ instantiate_copy_same(itname ##itname, itype) \ instantiate_copy_all(itname ##bool_, itype, bool) \ instantiate_copy_all(itname ##uint8, itype, uint8_t) \ instantiate_copy_all(itname ##uint16, itype, uint16_t) \ instantiate_copy_all(itname ##uint32, itype, uint32_t) \ instantiate_copy_base(itname ##uint64, itype, uint64_t) \ instantiate_copy_all(itname ##int8, itype, int8_t) \ instantiate_copy_all(itname ##int16, itype, int16_t) \ instantiate_copy_all(itname ##int32, itype,
int32_t) \ instantiate_copy_base(itname ##int64, itype, int64_t) \ instantiate_copy_all(itname ##float16, itype, half) \ instantiate_copy_all(itname ##float32, itype, float) \ instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \ instantiate_copy_base(itname ##complex64, itype, complex64_t) instantiate_copy_itype(bool_, bool) instantiate_copy_itype(uint8, uint8_t) instantiate_copy_itype(uint16, uint16_t) instantiate_copy_itype(uint32, uint32_t) instantiate_copy_itype(uint64, uint64_t) instantiate_copy_itype(int8, int8_t) instantiate_copy_itype(int16, int16_t) instantiate_copy_itype(int32, int32_t) instantiate_copy_itype(int64, int64_t) instantiate_copy_itype(float16, half) instantiate_copy_itype(float32, float) instantiate_copy_itype(bfloat16, bfloat16_t) instantiate_copy_itype(complex64, complex64_t) // clang-format on ================================================ FILE: mlx/backend/metal/kernels/defines.h ================================================ // Copyright © 2023 Apple Inc. #pragma once /* MTL_CONST is the Metal constant address space when compiled as Metal or for the JIT, and empty otherwise (presumably so host C++ can include this header too). */ #if defined __METAL__ || defined MLX_METAL_JIT #define MTL_CONST constant #else #define MTL_CONST #endif static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; static MTL_CONST constexpr int REDUCE_N_READS = 4; static MTL_CONST constexpr int REDUCE_N_WRITES = 4; static MTL_CONST constexpr int SOFTMAX_N_READS = 4; static MTL_CONST constexpr int RMS_N_READS = 4; static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; // Instantiate a templated kernel. // Extra args are used as template parameters: // e.g. instantiate_kernel(binary_int, binary, a, b) -> // [[host_name(binary_int)]] [[kernel]] binary instantiated with template arguments a, b #define instantiate_kernel(name, func, ...) \ template [[host_name( \ name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; ================================================ FILE: mlx/backend/metal/kernels/erf.h ================================================ // Copyright © 2023 Apple Inc.
#pragma once

// NOTE(review): the header name of the next #include is missing in this copy
// of the file (only `#include` remains); restore it from upstream.
#include
#include "mlx/backend/metal/kernels/expm1f.h"

/*
 * Approximation to the error function.
 * Based on code from:
 * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
 *
 * Two approximations are selected on |a|:
 *  - |a| >  0.927734375: a polynomial-based exponent r is built from
 *    t = |a|, the result is -expm1f(r) (i.e. 1 - exp(r)), and the sign of
 *    `a` is copied back with copysign.
 *  - |a| <= 0.927734375: a + a * P(a * a), an odd polynomial in `a`.
 */
float erf(float a) {
  float r, s, t, u;
  t = metal::abs(a);
  s = a * a;
  if (t > 0.927734375f) {
    // maximum error 0.99527 ulp
    r = metal::fma(
        -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
    u = metal::fma(
        -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
    r = metal::fma(r, s, u);
    r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
    r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
    r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
    r = metal::fma(r, t, -t);
    // -(exp(r) - 1) == 1 - exp(r)
    r = -expm1f(r);
    r = metal::copysign(r, a);
  } else {
    // maximum error 0.98929 ulp
    r = -5.96761703e-4f; // -0x1.38e000p-11
    r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
    r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
    r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
    r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
    r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
    // a + a * P(a^2)
    r = metal::fma(r, a, a);
  }
  return r;
}

/*
 * Approximation to the inverse error function.
 * With t = log(1 - a * a) (the fma forms 1 - a * a with a single rounding),
 * the result is a * p(t), where p is one of two polynomials in t selected on
 * |t| > 6.125 (i.e. |a| close to 1).
 * NOTE(review): there is no explicit special-casing of |a| >= 1 here; the
 * result at those inputs follows from log() of a non-positive value —
 * confirm callers expect that.
 */
float erfinv(float a) {
  auto t = metal::fma(a, 0.0f - a, 1.0f);
  t = metal::log(t);
  float p;
  if (metal::abs(t) > 6.125f) {
    // maximum ulp error = 2.35793
    p = 3.03697567e-10f; // 0x1.4deb44p-32
    p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
    p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
    p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
    p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
    p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
    p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
    p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
    p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
  } else {
    // maximum ulp error = 2.35002
    p = 5.43877832e-9f; // 0x1.75c000p-28
    p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
    p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
    p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
    p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
    p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
    p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
    p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
    p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
    p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
  }
  return a * p;
}

================================================
FILE: mlx/backend/metal/kernels/expm1f.h
================================================
// Copyright © 2023 Apple Inc.

#pragma once

// NOTE(review): the header name of the next #include is missing in this copy
// of the file (only `#include` remains); restore it from upstream.
#include

// Original license copied below:
// Copyright (c) 2015-2023 Norbert Juffa
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

/*
  Compute exponential base e minus 1. Maximum ulp error = 0.997458

  i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
  Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
  With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
  when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.

  NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
*/
float expm1f_scaled_unchecked(float a, float b) {
  float f, j, r, s, t, u, v, x, y;
  int i;

  // exp(a) = 2**i * exp(f); i = rintf (a / log(2))
  // Adding and then subtracting 0x1.8p23 rounds a * log2(e) to an integer.
  j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
  j = j - 12582912.0f; // 0x1.8p23
  i = (int)j;
  // f = a - j * log(2), with log(2) split into a high and a low part so the
  // reduced argument keeps full single precision. Dropping the low-part step
  // leaves an absolute error of about |j| * 1.43e-6 in f, which breaks the
  // ulp bound stated above.
  f = fma(j, -6.93145752e-1f, a); // -0x1.62e400p-1 (log_2_hi)
  f = fma(j, -1.42860677e-6f, f); // -0x1.7f7d1cp-20 (log_2_lo)

  // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
  s = f * f;
  if (a == 0.0f)
    s = a; // ensure -0 is passed through
  // err = 0.997458 ulp1 = 11081805
  r = 1.97350979e-4f; // 0x1.9de000p-13
  r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
  r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
  r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
  r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
  r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
  u = (j == 1) ? (f + 0.5f) : f;
  v = fma(r, s, u);
  s = 0.5f * b;
  t = ldexp(s, i);
  y = t - s;
  x = (t - y) - s; // double-float canonicalization of difference
  r = fma(v, t, x) + y;
  r = r + r;
  if (j == 0)
    r = v;
  if (j == 1)
    r = v + v;
  return r;
}

/*
  Compute exponential base e minus 1.
  max ulp err = 0.99746
*/
float expm1f(float a) {
  float r;

  r = expm1f_scaled_unchecked(a, 1.0f);
  /* handle severe overflow and underflow */
  // |a - 1| > 88 means a > 89 or a < -87. There (2^a)^2 - 1 overflows to
  // +inf or underflows to -1 respectively, the saturated values of
  // exp(a) - 1 in single precision.
  if (abs(a - 1.0f) > 88.0f) {
    r = pow(2, a);
    r = fma(r, r, -1.0f);
  }

  return r;
}

================================================
FILE: mlx/backend/metal/kernels/fence.metal
================================================
// Copyright © 2024 Apple Inc.

// NOTE(review): presumably needed to expose the internal memory-scope
// machinery used below (coherent(system), thread_scope) — confirm.
#pragma METAL internals : enable

#ifndef __METAL_MEMORY_SCOPE_SYSTEM__
#define __METAL_MEMORY_SCOPE_SYSTEM__ 3
#endif

// Declares a system-wide thread_scope value (3 unless the toolchain already
// defines __METAL_MEMORY_SCOPE_SYSTEM__).
// NOTE(review): the static_cast target type is missing in this copy of the
// file; presumably metal::thread_scope — restore from upstream.
namespace metal {

constexpr constant metal::thread_scope thread_scope_system =
    static_cast(__METAL_MEMORY_SCOPE_SYSTEM__);

}

// NOTE(review): header name missing in this copy; restore from upstream.
#include

// Each in-range thread rewrites its own element through a
// coherent(system) pointer; every thread then issues a sequentially
// consistent device-memory fence at system scope.
[[kernel]] void input_coherent(
    volatile coherent(system) device uint* input [[buffer(0)]],
    const constant uint& size [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  if (index < size) {
    input[index] = input[index];
  }
  metal::atomic_thread_fence(
      metal::mem_flags::mem_device,
      metal::memory_order_seq_cst,
      metal::thread_scope_system);
}

// single thread kernel to update timestamp
// Stores `value` into timestamp[0], then fences at system scope.
[[kernel]] void fence_update(
    volatile coherent(system) device uint* timestamp [[buffer(0)]],
    constant uint& value [[buffer(1)]]) {
  timestamp[0] = value;
  metal::atomic_thread_fence(
      metal::mem_flags::mem_device,
      metal::memory_order_seq_cst,
      metal::thread_scope_system);
}

// single thread kernel to spin wait for timestamp value
// Spins (fence, then re-read) until timestamp[0] >= value. There is no
// timeout: if the value is never reached this kernel never returns.
[[kernel]] void fence_wait(
    volatile coherent(system) device uint* timestamp [[buffer(0)]],
    constant uint& value [[buffer(1)]]) {
  while (1) {
    metal::atomic_thread_fence(
        metal::mem_flags::mem_device,
        metal::memory_order_seq_cst,
        metal::thread_scope_system);
    if (timestamp[0] >= value) {
      break;
    }
  }
}

================================================
FILE: mlx/backend/metal/kernels/fft/radix.h
================================================
// Copyright © 2024 Apple Inc.

/* Radix kernels

We provide optimized, single threaded Radix codelets
for n=2,3,4,5,6,7,8,10,11,12,13.

For n=2,3,4,5,6 we hand write the codelets.
For n=8,10,12 we combine smaller codelets.
For n=7,11,13 we use Rader's algorithm which decomposes
them into (n-1)=6,10,12 codelets.

All codelets compute a forward DFT y = DFT(x) of the given size, reading
x[0..n) and writing y[0..n). Several of them (radix7, radix8, radix10,
radix11, radix12, radix13) also use x as scratch space and overwrite it.
*/

#pragma once

// NOTE(review): the header names of these three #includes are missing in
// this copy of the file; restore them from upstream.
#include
#include
#include

// Complex product a * b, with complex values stored as (re, im) in a float2.
METAL_FUNC float2 complex_mul(float2 a, float2 b) {
  return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}

// Complex mul followed by conjugate
// Returns conj(a * b).
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
  return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);
}

// Compute an FFT twiddle factor
// Returns exp(-2*pi*i * k / p) as (cos, sin), using the fast trig variants.
METAL_FUNC float2 get_twiddle(int k, int p) {
  float theta = -2.0f * k * M_PI_F / p;

  float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)};
  return twiddle;
}

METAL_FUNC void radix2(thread float2* x, thread float2* y) {
  y[0] = x[0] + x[1];
  y[1] = x[0] - x[1];
}

METAL_FUNC void radix3(thread float2* x, thread float2* y) {
  // -sin(pi/3); the imaginary part of exp(-2*pi*i/3)
  float pi_2_3 = -0.8660254037844387;

  float2 a_1 = x[1] + x[2];
  float2 a_2 = x[1] - x[2];

  y[0] = x[0] + a_1;
  float2 b_1 = x[0] - 0.5 * a_1;
  float2 b_2 = pi_2_3 * a_2;

  // b_2_j = i * b_2
  float2 b_2_j = {-b_2.y, b_2.x};
  y[1] = b_1 + b_2_j;
  y[2] = b_1 - b_2_j;
}

METAL_FUNC void radix4(thread float2* x, thread float2* y) {
  float2 z_0 = x[0] + x[2];
  float2 z_1 = x[0] - x[2];
  float2 z_2 = x[1] + x[3];
  float2 z_3 = x[1] - x[3];
  // z_3_i = -i * z_3
  float2 z_3_i = {z_3.y, -z_3.x};

  y[0] = z_0 + z_2;
  y[1] = z_1 + z_3_i;
  y[2] = z_0 - z_2;
  y[3] = z_1 - z_3_i;
}

METAL_FUNC void radix5(thread float2* x, thread float2* y) {
  // sqrt(5) / 4, sin(2*pi/5), sin(pi/5), broadcast into float2
  float2 root_5_4 = 0.5590169943749475;
  float2 sin_2pi_5 = 0.9510565162951535;
  float2 sin_1pi_5 = 0.5877852522924731;

  float2 a_1 = x[1] + x[4];
  float2 a_2 = x[2] + x[3];
  float2 a_3 = x[1] - x[4];
  float2 a_4 = x[2] - x[3];

  float2 a_5 = a_1 + a_2;
  float2 a_6 = root_5_4 * (a_1 - a_2);
  float2 a_7 = x[0] - a_5 / 4;
  float2 a_8 = a_7 + a_6;
  float2 a_9 = a_7 - a_6;
  float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4;
  float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4;
  // multiply by -i
  float2 a_10_j = {a_10.y, -a_10.x};
  float2 a_11_j = {a_11.y, -a_11.x};

  y[0] = x[0] + a_5;
  y[1] = a_8 + a_10_j;
  y[2] = a_9 + a_11_j;
  y[3] = a_9 - a_11_j;
  y[4] = a_8 - a_10_j;
}

METAL_FUNC void radix6(thread float2* x, thread float2* y) {
  float sin_pi_3 = 0.8660254037844387;
  float2 a_1 = x[2] + x[4];
  float2 a_2 = x[0] - a_1 / 2;
  float2 a_3 = sin_pi_3 * (x[2] - x[4]);
  float2 a_4 = x[5] + x[1];
  float2 a_5 = x[3] - a_4 / 2;
  float2 a_6 = sin_pi_3 * (x[5] - x[1]);
  float2 a_7 = x[0] + a_1;

  // multiply by -i
  float2 a_3_i = {a_3.y, -a_3.x};
  float2 a_6_i = {a_6.y, -a_6.x};
  float2 a_8 = a_2 + a_3_i;
  float2 a_9 = a_2 - a_3_i;
  float2 a_10 = x[3] + a_4;
  float2 a_11 = a_5 + a_6_i;
  float2 a_12 = a_5 - a_6_i;

  y[0] = a_7 + a_10;
  y[1] = a_8 - a_11;
  y[2] = a_9 + a_12;
  y[3] = a_7 - a_10;
  y[4] = a_8 + a_11;
  y[5] = a_9 - a_12;
}

// Size-7 DFT via Rader's algorithm: a length-6 cyclic convolution evaluated
// with radix6 (forward), pointwise multiply, radix6 again (as an inverse via
// conjugation). Overwrites x[1..6] as scratch.
METAL_FUNC void radix7(thread float2* x, thread float2* y) {
  // Rader's algorithm
  // Multiplying componentwise by (1/6, -1/6) conjugates and scales by 1/6,
  // turning the second forward radix6 into an inverse transform.
  float2 inv = {1 / 6.0, -1 / 6.0};

  // fft
  // Input indices are successive powers of 3 mod 7: 1, 3, 2, 6, 4, 5.
  float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]};
  radix6(in1, y + 1);

  // y[1] now holds the sum of x[1..6], so this is the DC term.
  y[0] = y[1] + x[0];

  // b_q
  // Precomputed constants (each has magnitude ~sqrt(7)); presumably the DFT
  // of the Rader twiddle sequence. complex_mul_conj also conjugates for the
  // inverse transform that follows.
  y[1] = complex_mul_conj(y[1], float2(-1, 0));
  y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879));
  y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629));
  y[4] = complex_mul_conj(y[4], float2(0, -2.64575131));
  y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629));
  y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879));

  // ifft
  radix6(y + 1, x + 1);

  // Output indices are successive powers of 5 (= 3^-1) mod 7: 1, 5, 4, 6, 2, 3.
  y[1] = x[1] * inv + x[0];
  y[5] = x[2] * inv + x[0];
  y[4] = x[3] * inv + x[0];
  y[6] = x[4] * inv + x[0];
  y[2] = x[5] * inv + x[0];
  y[3] = x[6] * inv + x[0];
}

// Size-8 DFT as two radix4 codelets (even / odd inputs) plus one combining
// stage. Overwrites x as scratch.
METAL_FUNC void radix8(thread float2* x, thread float2* y) {
  float cos_pi_4 = 0.7071067811865476;
  // exp(-i*pi/4) and exp(-3i*pi/4)
  float2 w_0 = {cos_pi_4, -cos_pi_4};
  float2 w_1 = {-cos_pi_4, -cos_pi_4};
  float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]};
  radix4(temp, x);
  radix4(temp + 4, x + 4);

  y[0] = x[0] + x[4];
  y[4] = x[0] - x[4];
  float2 x_5 = complex_mul(x[5], w_0);
  y[1] = x[1] + x_5;
  y[5] = x[1] - x_5;
  // multiply by -i (twiddle exp(-i*pi/2))
  float2 x_6 = {x[6].y, -x[6].x};
  y[2] = x[2] + x_6;
  y[6] = x[2] - x_6;
  float2 x_7 = complex_mul(x[7], w_1);
  y[3] = x[3] + x_7;
  y[7] = x[3] - x_7;
}

// Size-10 DFT as two radix5 codelets plus one combining stage. Overwrites x
// as scratch.
// `raders_perm` selects the input ordering. When true, the inputs are first
// permuted by successive powers of 2 mod 11 (minus 1), so that radix11 can
// pass its raw x + 1 directly; otherwise a plain even/odd split is used.
// NOTE(review): the template parameter list (declaring `raders_perm`) and the
// template arguments at the radix11 call sites are missing in this copy of
// the file; restore them from upstream.
template
METAL_FUNC void radix10(thread float2* x, thread float2* y) {
  // w[k] = exp(-2*pi*i * (k + 1) / 10)
  float2 w[4];
  w[0] = {0.8090169943749475, -0.5877852522924731};
  w[1] = {0.30901699437494745, -0.9510565162951535};
  w[2] = {-w[1].x, w[1].y};
  w[3] = {-w[0].x, w[0].y};

  if (raders_perm) {
    float2 temp[10] = {
        x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]};
    radix5(temp, x);
    radix5(temp + 5, x + 5);
  } else {
    float2 temp[10] = {
        x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]};
    radix5(temp, x);
    radix5(temp + 5, x + 5);
  }

  y[0] = x[0] + x[5];
  y[5] = x[0] - x[5];
  for (int t = 1; t < 5; t++) {
    float2 a = complex_mul(x[t + 5], w[t - 1]);
    y[t] = x[t] + a;
    y[t + 5] = x[t] - a;
  }
}

// Size-11 DFT via Rader's algorithm on top of radix10 (see radix7 for the
// structure). Overwrites x[1..10] as scratch.
METAL_FUNC void radix11(thread float2* x, thread float2* y) {
  // Raders Algorithm
  float2 inv = {1 / 10.0, -1 / 10.0};

  // fft
  radix10(x + 1, y + 1);

  y[0] = y[1] + x[0];

  // b_q
  // Precomputed constants (magnitude ~sqrt(11)); presumably the DFT of the
  // Rader twiddle sequence.
  y[1] = complex_mul_conj(y[1], float2(-1, 0));
  y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649));
  y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656));
  y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479));
  y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150));
  y[6] = complex_mul_conj(y[6], float2(0, -3.31662479));
  y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150));
  y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479));
  y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656));
  y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649));

  // ifft
  radix10(y + 1, x + 1);

  // Output indices are successive powers of 6 (= 2^-1) mod 11.
  y[1] = x[1] * inv + x[0];
  y[6] = x[2] * inv + x[0];
  y[3] = x[3] * inv + x[0];
  y[7] = x[4] * inv + x[0];
  y[9] = x[5] * inv + x[0];
  y[10] = x[6] * inv + x[0];
  y[5] = x[7] * inv + x[0];
  y[8] = x[8] * inv + x[0];
  y[4] = x[9] * inv + x[0];
  y[2] = x[10] * inv + x[0];
}

// Size-12 DFT as two radix6 codelets plus one combining stage. Overwrites x
// as scratch. With `raders_perm`, the inputs are permuted by successive powers
// of 2 mod 13 (minus 1) for use by radix13.
// NOTE(review): template parameter list missing in this copy; see radix10.
template
METAL_FUNC void radix12(thread float2* x, thread float2* y) {
  // w[k] = exp(-2*pi*i * (k + 1) / 12) for k = 0..4; w[5] is never used.
  float2 w[6];
  float sin_pi_3 = 0.8660254037844387;
  w[0] = {sin_pi_3, -0.5};
  w[1] = {0.5, -sin_pi_3};
  w[2] = {0, -1};
  w[3] = {-0.5, -sin_pi_3};
  w[4] = {-sin_pi_3, -0.5};

  if (raders_perm) {
    float2 temp[12] = {
        x[0], x[3], x[2], x[11], x[8], x[9], x[1], x[7], x[5], x[10], x[4], x[6]};
    radix6(temp, x);
    radix6(temp + 6, x + 6);
  } else {
    float2 temp[12] = {
        x[0], x[2], x[4], x[6], x[8], x[10], x[1], x[3], x[5], x[7], x[9], x[11]};
    radix6(temp, x);
    radix6(temp + 6, x + 6);
  }

  y[0] = x[0] + x[6];
  y[6] = x[0] - x[6];
  for (int t = 1; t < 6; t++) {
    float2 a = complex_mul(x[t + 6], w[t - 1]);
    y[t] = x[t] + a;
    y[t + 6] = x[t] - a;
  }
}

// Size-13 DFT via Rader's algorithm on top of radix12 (see radix7 for the
// structure). Overwrites x[1..12] as scratch.
METAL_FUNC void radix13(thread float2* x, thread float2* y) {
  // Raders Algorithm
  float2 inv = {1 / 12.0, -1 / 12.0};

  // fft
  radix12(x + 1, y + 1);

  y[0] = y[1] + x[0];

  // b_q
  // Precomputed constants (magnitude ~sqrt(13)); presumably the DFT of the
  // Rader twiddle sequence.
  y[1] = complex_mul_conj(y[1], float2(-1, 0));
  y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669));
  y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823));
  y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161));
  y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690));
  y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267));
  y[7] = complex_mul_conj(y[7], float2(3.60555128, 0));
  y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267));
  y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690));
  y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161));
  y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823));
  y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669));

  // ifft
  radix12(y + 1, x + 1);

  // Output indices are successive powers of 7 (= 2^-1) mod 13.
  y[1] = x[1] * inv + x[0];
  y[7] = x[2] * inv + x[0];
  y[10] = x[3] * inv + x[0];
  y[5] = x[4] * inv + x[0];
  y[9] = x[5] * inv + x[0];
  y[11] = x[6] * inv + x[0];
  y[12] = x[7] * inv + x[0];
  y[6] = x[8] * inv + x[0];
  y[3] = x[9] * inv + x[0];
  y[8] = x[10] * inv + x[0];
  y[4] = x[11] * inv + x[0];
  y[2] = x[12] * inv + x[0];
}

================================================
FILE: mlx/backend/metal/kernels/fft/readwrite.h
================================================
// Copyright © 2024 Apple Inc.

// NOTE(review): header name missing in this copy; restore from upstream.
#include

#include "mlx/backend/metal/kernels/fft/radix.h"

/* FFT helpers for reading and writing from/to device memory.
For many sizes, GPU FFTs are memory bandwidth bound so
read/write performance is important.

Where possible, we read 128 bits sequentially in each thread,
coalesced with accesses from adjacent threads for optimal performance.

We implement specialized reading/writing for:
  - FFT
  - RFFT
  - IRFFT

Each with support for:
  - Contiguous reads
  - Padded reads
  - Strided reads
*/

#define MAX_RADIX 13

using namespace metal;

// Moves FFT data between device memory (`in` / `out`) and the threadgroup
// buffer `buf`, applying the conjugation and 1/n scaling needed to compute an
// inverse FFT with a forward transform.
//
// `elem` is this thread's position and `grid` the launch size (both uint3).
// From the index arithmetic below, .x appears to select a group of grid.y
// sequences, .y a sequence within the threadgroup, and .z a thread working on
// that sequence — NOTE(review): inferred, confirm against the dispatch code.
//
// NOTE(review): index arithmetic in load()/write() uses `short`; this assumes
// grid.y * n fits in 16 bits — presumably guaranteed by the threadgroup memory
// budget; confirm.
template <
    typename in_T,
    typename out_T,
    int step = 0,
    bool four_step_real = false>
struct ReadWriter {
  const device in_T* in;
  threadgroup float2* buf;
  device out_T* out;
  int n;
  int batch_size;
  int elems_per_thread;
  uint3 elem;
  uint3 grid;
  int threads_per_tg;
  bool inv;

  // Used for strided access
  int strided_device_idx = 0;
  int strided_shared_idx = 0;

  METAL_FUNC ReadWriter(
      const device in_T* in_,
      threadgroup float2* buf_,
      device out_T* out_,
      const short n_,
      const int batch_size_,
      const short elems_per_thread_,
      const uint3 elem_,
      const uint3 grid_,
      const bool inv_)
      : in(in_),
        buf(buf_),
        out(out_),
        n(n_),
        batch_size(batch_size_),
        elems_per_thread(elems_per_thread_),
        elem(elem_),
        grid(grid_),
        inv(inv_) {
    // Account for padding on last threadgroup
    threads_per_tg = elem.x == grid.x - 1
        ? (batch_size - (grid.x - 1) * grid.y) * grid.z
        : grid.y * grid.z;
  }

  // ifft(x) = 1/n * conj(fft(conj(x)))
  METAL_FUNC float2 post_in(float2 elem) const {
    return inv ? float2(elem.x, -elem.y) : elem;
  }

  // Handle float case for generic RFFT alg
  METAL_FUNC float2 post_in(float elem) const {
    return float2(elem, 0);
  }

  METAL_FUNC float2 pre_out(float2 elem) const {
    return inv ? float2(elem.x / n, -elem.y / n) : elem;
  }

  METAL_FUNC float2 pre_out(float2 elem, int length) const {
    return inv ? float2(elem.x / length, -elem.y / length) : elem;
  }

  METAL_FUNC bool out_of_bounds() const {
    // Account for possible extra threadgroups
    int grid_index = elem.x * grid.y + elem.y;
    return grid_index >= batch_size;
  }

  // Contiguous load of all sequences handled by this threadgroup into `buf`.
  // Indices are clamped to the last valid position, so out-of-range threads
  // re-read (and rewrite) the final elements instead of overrunning.
  METAL_FUNC void load() const {
    size_t batch_idx = size_t(elem.x * grid.y) * n;
    short tg_idx = elem.y * grid.z + elem.z;
    short max_index = grid.y * n - 2;

    // 2 complex64s = 128 bits
    constexpr int read_width = 2;
    for (short e = 0; e < (elems_per_thread / read_width); e++) {
      short index = read_width * tg_idx + read_width * threads_per_tg * e;
      index = metal::min(index, max_index);
      // vectorized reads
      buf[index] = post_in(in[batch_idx + index]);
      buf[index + 1] = post_in(in[batch_idx + index + 1]);
    }
    max_index += 1;
    if (elems_per_thread % 2 != 0) {
      short index = tg_idx +
          read_width * threads_per_tg * (elems_per_thread / read_width);
      index = metal::min(index, max_index);
      buf[index] = post_in(in[batch_idx + index]);
    }
  }

  // Contiguous store mirroring load().
  METAL_FUNC void write() const {
    size_t batch_idx = size_t(elem.x * grid.y) * n;
    short tg_idx = elem.y * grid.z + elem.z;
    short max_index = grid.y * n - 2;

    constexpr int read_width = 2;
    for (short e = 0; e < (elems_per_thread / read_width); e++) {
      short index = read_width * tg_idx + read_width * threads_per_tg * e;
      index = metal::min(index, max_index);
      // vectorized writes (pairs of adjacent elements)
      out[batch_idx + index] = pre_out(buf[index]);
      out[batch_idx + index + 1] = pre_out(buf[index + 1]);
    }
    max_index += 1;
    if (elems_per_thread % 2 != 0) {
      short index = tg_idx +
          read_width * threads_per_tg * (elems_per_thread / read_width);
      index = metal::min(index, max_index);
      out[batch_idx + index] = pre_out(buf[index]);
    }
  }

  // Padded IO for Bluestein's algorithm
  // Loads `length` input elements multiplied by w_k and zero-fills the rest
  // of the n-element sequence buffer.
  METAL_FUNC void load_padded(int length, const device float2* w_k) const {
    size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
    int fft_idx = elem.z;
    int m = grid.z;

    threadgroup float2* seq_buf = buf + elem.y * n;
    for (int e = 0; e < elems_per_thread; e++) {
      int index = metal::min(fft_idx + e * m, n - 1);
      if (index < length) {
        float2 elem = post_in(in[batch_idx + index]);
        seq_buf[index] = complex_mul(elem, w_k[index]);
      } else {
        seq_buf[index] = 0.0;
      }
    }
  }

  // Stores `length` outputs taken from offset length - 1 of the sequence
  // buffer, scaled by (1/n, -1/n) and multiplied by w_k.
  METAL_FUNC void write_padded(int length, const device float2* w_k) const {
    size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
    int fft_idx = elem.z;
    int m = grid.z;
    float2 inv_factor = {1.0f / n, -1.0f / n};

    threadgroup float2* seq_buf = buf + elem.y * n;
    for (int e = 0; e < elems_per_thread; e++) {
      int index = metal::min(fft_idx + e * m, n - 1);
      if (index < length) {
        float2 elem = seq_buf[index + length - 1] * inv_factor;
        out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length);
      }
    }
  }

  // Strided IO for four step FFT
  METAL_FUNC void compute_strided_indices(int stride, int overall_n) {
    // Use the batch threadgroup dimension to coalesce memory accesses:
    // e.g. stride = 12
    // device      | shared mem
    // 0  1  2  3  | 0 12 - -
    // -  -  -  -  | 1 13 - -
    // -  -  -  -  | 2 14 - -
    // 12 13 14 15 | 3 15 - -
    int coalesce_width = grid.y;
    int tg_idx = elem.y * grid.z + elem.z;
    int outer_batch_size = stride / coalesce_width;

    int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
        overall_n * (elem.x / outer_batch_size);
    strided_device_idx = strided_batch_idx +
        tg_idx / coalesce_width * elems_per_thread * stride +
        tg_idx % coalesce_width;
    strided_shared_idx = (tg_idx % coalesce_width) * n +
        tg_idx / coalesce_width * elems_per_thread;
  }

  // Four Step FFT First Step
  METAL_FUNC void load_strided(int stride, int overall_n) {
    compute_strided_indices(stride, overall_n);
    for (int e = 0; e < elems_per_thread; e++) {
      buf[strided_shared_idx + e] =
          post_in(in[strided_device_idx + e * stride]);
    }
  }

  // Relies on the strided indices computed by a preceding load_strided().
  METAL_FUNC void write_strided(int stride, int overall_n) {
    for (int e = 0; e < elems_per_thread; e++) {
      float2 output = buf[strided_shared_idx + e];
      int combined_idx = (strided_device_idx + e * stride) % overall_n;
      int ij = (combined_idx / stride) * (combined_idx % stride);
      // Apply four step twiddles at end of first step
      float2 twiddle = get_twiddle(ij, overall_n);
      out[strided_device_idx + e * stride] = complex_mul(output, twiddle);
    }
  }
};

// NOTE(review): in this copy of the file the template argument lists of the
// `ReadWriter` explicit specializations below are missing (only
// `ReadWriter::` remains); restore them from upstream. The section comments
// identify which variant each specialization implements.

// Four Step FFT Second Step
template <>
METAL_FUNC void ReadWriter::load_strided(
    int stride,
    int overall_n) {
  // Silence compiler warnings
  (void)stride;
  (void)overall_n;
  // Don't invert between steps
  bool default_inv = inv;
  inv = false;
  load();
  inv = default_inv;
}

template <>
METAL_FUNC void ReadWriter::write_strided(
    int stride,
    int overall_n) {
  compute_strided_indices(stride, overall_n);
  for (int e = 0; e < elems_per_thread; e++) {
    float2 output = buf[strided_shared_idx + e];
    out[strided_device_idx + e * stride] = pre_out(output, overall_n);
  }
}

// For RFFT, we interleave batches of two real sequences into one complex one:
//
// z_k = x_k + j.y_k
// X_k = (Z_k + Z_(N-k)*) / 2
// Y_k = -j * ((Z_k - Z_(N-k)*) / 2)
//
// This roughly doubles the throughput over the regular FFT.
template <>
METAL_FUNC bool ReadWriter::out_of_bounds() const {
  int grid_index = elem.x * grid.y + elem.y;
  // We pack two sequences into one for RFFTs
  return grid_index * 2 >= batch_size;
}

template <>
METAL_FUNC void ReadWriter::load() const {
  size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  // (the unpaired last sequence is read twice instead of reading past it)
  int grid_index = elem.x * grid.y + elem.y;
  short next_in =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread; e++) {
    int index = metal::min(fft_idx + e * m, n - 1);
    seq_buf[index].x = in[batch_idx + index];
    seq_buf[index].y = in[batch_idx + index + next_in];
  }
}

template <>
METAL_FUNC void ReadWriter::write() const {
  short n_over_2 = (n / 2) + 1;

  size_t batch_idx =
      size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;

  float2 conj = {1, -1};
  float2 minus_j = {0, -1};

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
    int index = metal::min(fft_idx + e * m, n_over_2 - 1);
    // x_0 = z_0.real
    // y_0 = z_0.imag
    if (index == 0) {
      out[batch_idx + index] = {seq_buf[index].x, 0};
      out[batch_idx + index + next_out] = {seq_buf[index].y, 0};
    } else {
      float2 x_k = seq_buf[index];
      float2 x_n_minus_k = seq_buf[n - index] * conj;
      out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
      out[batch_idx + index + next_out] =
          complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
    }
  }
}

template <>
METAL_FUNC void ReadWriter::load_padded(
    int length,
    const device float2* w_k) const {
  size_t batch_idx =
      size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread; e++) {
    int index = metal::min(fft_idx + e * m, n - 1);
    if (index < length) {
      float2 elem =
          float2(in[batch_idx + index], in[batch_idx + index + next_in]);
      seq_buf[index] = complex_mul(elem, w_k[index]);
    } else {
      seq_buf[index] = 0;
    }
  }
}

template <>
METAL_FUNC void ReadWriter::write_padded(
    int length,
    const device float2* w_k) const {
  int length_over_2 = (length / 2) + 1;
  size_t batch_idx =
      size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n + length - 1;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
      ? 0
      : length_over_2;

  float2 conj = {1, -1};
  float2 inv_factor = {1.0f / n, -1.0f / n};
  float2 minus_j = {0, -1};

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
    int index = metal::min(fft_idx + e * m, length_over_2 - 1);
    // x_0 = z_0.real
    // y_0 = z_0.imag
    if (index == 0) {
      float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor);
      out[batch_idx + index] = float2(elem.x, 0);
      out[batch_idx + index + next_out] = float2(elem.y, 0);
    } else {
      float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor);
      float2 x_n_minus_k = complex_mul(
          w_k[length - index], seq_buf[length - index] * inv_factor);
      x_n_minus_k *= conj;
      // w_k should happen before this extraction
      out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
      out[batch_idx + index + next_out] =
          complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
    }
  }
}

// For IRFFT, we do the opposite:
//
// Z_k = X_k + j.Y_k     (pack the two half-spectra into one)
// z   = ifft(Z)
// x_k = Re(z_k)
// y_k = Im(z_k)
template <>
METAL_FUNC bool ReadWriter::out_of_bounds() const {
  int grid_index = elem.x * grid.y + elem.y;
  // We pack two sequences into one for IRFFTs
  return grid_index * 2 >= batch_size;
}

// Rebuilds the full length-n packed spectrum from the n/2 + 1 stored
// coefficients of each of the two inputs (using Hermitian symmetry for the
// upper half) and conjugates it, so a forward FFT yields the inverse.
template <>
METAL_FUNC void ReadWriter::load() const {
  short n_over_2 = (n / 2) + 1;
  size_t batch_idx =
      size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;

  short m = grid.z;
  short fft_idx = elem.z;

  float2 conj = {1, -1};
  float2 plus_j = {0, 1};

  for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
    int index = metal::min(fft_idx + t * m, n_over_2 - 1);
    float2 x = in[batch_idx + index];
    float2 y = in[batch_idx + index + next_in];
    // NumPy forces first input to be real
    bool first_val = index == 0;
    // NumPy forces last input on even irffts to be real
    bool last_val = n % 2 == 0 && index == n_over_2 - 1;
    if (first_val || last_val) {
      x = float2(x.x, 0);
      y = float2(y.x, 0);
    }
    seq_buf[index] = x + complex_mul(y, plus_j);
    seq_buf[index].y = -seq_buf[index].y;
    if (index > 0 && !last_val) {
      seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j);
      seq_buf[n - index].y = -seq_buf[n - index].y;
    }
  }
}

template <>
METAL_FUNC void ReadWriter::write() const {
  // size_t, like every other load/write offset: the output holds 2 * n reals
  // per packed pair, so an int offset overflows for large batches.
  size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread; e++) {
    int index = metal::min(fft_idx + e * m, n - 1);
    // Real part -> first sequence; imaginary part (conjugated) -> second.
    out[batch_idx + index] = seq_buf[index].x / n;
    out[batch_idx + index + next_out] = seq_buf[index].y / -n;
  }
}

template <>
METAL_FUNC void ReadWriter::load_padded(
    int length,
    const device float2* w_k) const {
  int n_over_2 = (n / 2) + 1;
  int length_over_2 = (length / 2) + 1;

  size_t batch_idx =
      size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
      ? 0
      : length_over_2;

  short m = grid.z;
  short fft_idx = elem.z;

  float2 conj = {1, -1};
  float2 plus_j = {0, 1};

  for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
    int index = metal::min(fft_idx + t * m, n_over_2 - 1);
    if (index < length_over_2) {
      // Only read device memory for indices that exist in the input: `index`
      // can reach n_over_2 - 1, well past length_over_2, and reading there
      // would run past the end of the last sequence in `in`.
      float2 x = in[batch_idx + index];
      float2 y = in[batch_idx + index + next_in];
      bool last_val = length % 2 == 0 && index == length_over_2 - 1;
      if (last_val) {
        x = float2(x.x, 0);
        y = float2(y.x, 0);
      }
      float2 elem1 = x + complex_mul(y, plus_j);
      seq_buf[index] = complex_mul(elem1 * conj, w_k[index]);
      if (index > 0 && !last_val) {
        float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j);
        seq_buf[length - index] =
            complex_mul(elem2 * conj, w_k[length - index]);
      }
    } else {
      // Zero-fill the Bluestein padding, two slots per index.
      short pad_index =
          metal::min(length + (index - length_over_2) * 2, n - 2);
      seq_buf[pad_index] = 0;
      seq_buf[pad_index + 1] = 0;
    }
  }
}

template <>
METAL_FUNC void ReadWriter::write_padded(
    int length,
    const device float2* w_k) const {
  size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
  threadgroup float2* seq_buf = buf + elem.y * n + length - 1;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;

  short m = grid.z;
  short fft_idx = elem.z;

  float2 inv_factor = {1.0f / n, -1.0f / n};
  for (int e = 0; e < elems_per_thread; e++) {
    int index = fft_idx + e * m;
    if (index < length) {
      float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]);
      out[batch_idx + index] = output.x / length;
      out[batch_idx + index + next_out] = output.y / -length;
    }
  }
}

// Four Step RFFT
template <>
METAL_FUNC void ReadWriter::load_strided(
    int stride,
    int overall_n) {
  // Silence compiler warnings
  (void)stride;
  (void)overall_n;
  // Don't invert between steps
  bool default_inv = inv;
  inv = false;
  load();
  inv = default_inv;
}

// Writes only the first overall_n / 2 + 1 outputs of each real transform.
template <>
METAL_FUNC void ReadWriter::write_strided(
    int stride,
    int overall_n) {
  int overall_n_over_2 = overall_n / 2 + 1;
  int coalesce_width = grid.y;
  int tg_idx = elem.y * grid.z + elem.z;
  int outer_batch_size = stride / coalesce_width;

  int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
      overall_n_over_2 * (elem.x / outer_batch_size);
  strided_device_idx = strided_batch_idx +
      tg_idx / coalesce_width * elems_per_thread / 2 * stride +
      tg_idx % coalesce_width;
  strided_shared_idx = (tg_idx % coalesce_width) * n +
      tg_idx / coalesce_width * elems_per_thread / 2;
  for (int e = 0; e < elems_per_thread / 2; e++) {
    float2 output = buf[strided_shared_idx + e];
    out[strided_device_idx + e * stride] = output;
  }

  // Add on n/2 + 1 element
  if (tg_idx == 0 && elem.x % outer_batch_size == 0) {
    out[strided_batch_idx + overall_n / 2] = buf[n / 2];
  }
}

// Four Step IRFFT
// Reads the half spectrum (overall_n / 2 + 1 entries per transform) and
// reconstructs the upper half from conjugate symmetry.
template <>
METAL_FUNC void ReadWriter::load_strided(
    int stride,
    int overall_n) {
  int overall_n_over_2 = overall_n / 2 + 1;
  auto conj = float2(1, -1);

  compute_strided_indices(stride, overall_n);
  // Translate indices in terms of N - k
  for (int e = 0; e < elems_per_thread; e++) {
    int device_idx = strided_device_idx + e * stride;
    int overall_batch = device_idx / overall_n;
    int overall_index = device_idx % overall_n;
    if (overall_index < overall_n_over_2) {
      device_idx -= overall_batch * (overall_n - overall_n_over_2);
      buf[strided_shared_idx + e] = in[device_idx] * conj;
    } else {
      int conj_idx = overall_n - overall_index;
      device_idx = overall_batch * overall_n_over_2 + conj_idx;
      buf[strided_shared_idx + e] = in[device_idx];
    }
  }
}

template <>
METAL_FUNC void ReadWriter::load_strided(
    int stride,
    int overall_n) {
  // Silence compiler warnings
  (void)stride;
  (void)overall_n;
  bool default_inv = inv;
  inv = false;
  load();
  inv = default_inv;
}

// Stores only the real part of each scaled output.
template <>
METAL_FUNC void ReadWriter::write_strided(
    int stride,
    int overall_n) {
  compute_strided_indices(stride, overall_n);

  for (int e = 0; e < elems_per_thread; e++) {
    out[strided_device_idx + e * stride] =
        pre_out(buf[strided_shared_idx + e], overall_n).x;
  }
}

================================================
FILE: mlx/backend/metal/kernels/fft.h
================================================
// Copyright © 2024 Apple Inc.

// Metal FFT using Stockham's algorithm
//
// References:
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)

// NOTE(review): header name missing in this copy; restore from upstream.
#include

#include "mlx/backend/metal/kernels/fft/radix.h"
#include "mlx/backend/metal/kernels/fft/readwrite.h"
#include "mlx/backend/metal/kernels/steel/defines.h"

using namespace metal;

#define MAX_RADIX 13
// Reached when elems_per_thread_ = 6, max_radix = 13
// and some threads have to do 3 radix 6s requiring 18 float2s.
#define MAX_OUTPUT_SIZE 18

// Function constants: each pipeline is specialized for one particular N at
// runtime through these values.
STEEL_CONST bool inv_ [[function_constant(0)]];
STEEL_CONST bool is_power_of_2_ [[function_constant(1)]];
STEEL_CONST int elems_per_thread_ [[function_constant(2)]];
// rader_m = n / rader_n
STEEL_CONST int rader_m_ [[function_constant(3)]];

// Number of Stockham passes to run for each radix.
STEEL_CONST int radix_13_steps_ [[function_constant(4)]];
STEEL_CONST int radix_11_steps_ [[function_constant(5)]];
STEEL_CONST int radix_8_steps_ [[function_constant(6)]];
STEEL_CONST int radix_7_steps_ [[function_constant(7)]];
STEEL_CONST int radix_6_steps_ [[function_constant(8)]];
STEEL_CONST int radix_5_steps_ [[function_constant(9)]];
STEEL_CONST int radix_4_steps_ [[function_constant(10)]];
STEEL_CONST int radix_3_steps_ [[function_constant(11)]];
STEEL_CONST int radix_2_steps_ [[function_constant(12)]];

// Number of passes per radix for the inner FFT used by Rader's algorithm.
STEEL_CONST int rader_13_steps_ [[function_constant(13)]];
STEEL_CONST int rader_11_steps_ [[function_constant(14)]];
STEEL_CONST int rader_8_steps_ [[function_constant(15)]];
STEEL_CONST int rader_7_steps_ [[function_constant(16)]];
STEEL_CONST int rader_6_steps_ [[function_constant(17)]];
STEEL_CONST int rader_5_steps_ [[function_constant(18)]];
STEEL_CONST int rader_4_steps_ [[function_constant(19)]];
STEEL_CONST int rader_3_steps_ [[function_constant(20)]];
STEEL_CONST int rader_2_steps_ [[function_constant(21)]];

// Signature of the radix codelets (defined in "radix.h"): inputs -> outputs.
using RadixFunc = void (*)(thread float2*, thread float2*);

// One radix-`radix` butterfly: twiddle the inputs, run the codelet and
// record the destination slot of each of the `radix` outputs.
//   i:       index of this butterfly within the overall DFT.
//   p:       size of the sub-DFTs being merged in this pass.
//   x:       the `radix` inputs (twiddled in place).
//   indices: receives the destination slot of each output.
//   y:       receives the `radix` codelet outputs.
template
METAL_FUNC void radix_butterfly(
    int i,
    int p,
    thread float2* x,
    thread short* indices,
    thread float2* y) {
  // Offset inside the current sub-DFT and the first destination slot.
  int pos_in_p;
  int out_base;

  // With power-of-two sizes, division and modulo reduce to masks and shifts.
  constexpr bool radix_is_pow2 = (radix & (radix - 1)) == 0;
  if (!(radix_is_pow2 && is_power_of_2_)) {
    pos_in_p = i % p;
    out_base = (i / p) * radix * p + pos_in_p;
  } else {
    constexpr short log2_radix = __builtin_ctz(radix);
    pos_in_p = i & (p - 1);
    out_base = ((i - pos_in_p) << log2_radix) + pos_in_p;
  }

  // The first pass (p == 1) needs no twiddle factors.
  if (p > 1) {
    float2 w_step = get_twiddle(pos_in_p, radix * p);
    float2 w = w_step;
    x[1] = complex_mul(x[1], w);
    STEEL_PRAGMA_UNROLL
    for (int t = 2; t < radix; t++) {
      // Build w_step^t by repeated multiplication.
      w = complex_mul(w, w_step);
      x[t] = complex_mul(x[t], w);
    }
  }

  radix_func(x, y);

  // Output t of this butterfly lands p slots after output t - 1.
  STEEL_PRAGMA_UNROLL
  for (int t = 0; t < radix; t++) {
    indices[t] = out_base + t * p;
  }
}

// Run `num_steps` Stockham passes of a single radix over the n values in buf.
// Each pass reads the inputs of every butterfly owned by this thread, runs the
// butterflies, synchronizes the threadgroup, scatters the outputs back into
// buf, synchronizes again and multiplies *p by the radix.
//   i: index of this thread within the DFT.
//   m: number of threads cooperating on the DFT.
template
METAL_FUNC void radix_n_steps(
    int i,
    thread int* p,
    int m,
    int n,
    int num_steps,
    thread float2* inputs,
    thread short* indices,
    thread float2* values,
    threadgroup float2* buf) {
  // Butterflies per pass; also the stride between one butterfly's inputs.
  const int butterflies = n / radix;

  // Mixing radices can leave a thread with several butterflies of one radix.
  // E.g. n = 28 = 4 * 7 with 4 threads and 7 elems_per_thread: every thread
  // does one radix-7 butterfly, 3 threads do two radix-4 butterflies and
  // 1 thread does a single radix-4 butterfly.
  const int per_thread = (elems_per_thread_ + radix - 1) / radix;

  for (int step = 0; step < num_steps; step++) {
    // Gather the inputs and compute every butterfly this thread owns.
    for (int t = 0; t < per_thread; t++) {
      const int b = i + t * m;
      if (b >= butterflies) {
        continue;
      }
      for (int r = 0; r < radix; r++) {
        inputs[r] = buf[b + r * butterflies];
      }
      radix_butterfly(b, *p, inputs, indices + t * radix, values + t * radix);
    }

    // All reads of buf must complete before any thread overwrites it.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Scatter each output to its slot for the next pass.
    for (int t = 0; t < per_thread; t++) {
      if (i + t * m >= butterflies) {
        continue;
      }
      const int first = t * radix;
      for (int slot = first; slot < first + radix; slot++) {
        buf[indices[slot]] = values[slot];
      }
    }

    // All writes must land before the next pass (or the caller) reads buf.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    *p *= radix;
  }
}

#define RADIX_STEP(radix, radix_func, num_steps) \
  radix_n_steps(                                 \
      fft_idx, p, m, n, num_steps, inputs, indices, values, buf);

// Compute a size-n Stockham FFT in buf, one radix at a time in increasing
// order. When `rader` is set, the pass counts come from the rader_*_steps_
// constants instead of the radix_*_steps_ ones.
template
METAL_FUNC void
perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {
  // Per-thread scratch reused by every pass: codelet inputs, plus the outputs
  // and their destination slots.
  float2 inputs[MAX_RADIX];
  short indices[MAX_OUTPUT_SIZE];
  float2 values[MAX_OUTPUT_SIZE];

  RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_);
  RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_);
  RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_);
  RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_);
  RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_);
  RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_);
  RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_);
  RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_);
  RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_);
}

// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-n DFTs:
// e.g.
128 = 2 * 4 * 4 * 4 template [[kernel]] void fft( const device in_T* in [[buffer(0)]], device out_T* out [[buffer(1)]], constant const int& n, constant const int& batch_size, uint3 elem [[thread_position_in_grid]], uint3 grid [[threads_per_grid]]) { threadgroup float2 shared_in[tg_mem_size]; thread ReadWriter read_writer = ReadWriter( in, &shared_in[0], out, n, batch_size, elems_per_thread_, elem, grid, inv_); if (read_writer.out_of_bounds()) { return; }; read_writer.load(); threadgroup_barrier(mem_flags::mem_threadgroup); int p = 1; int fft_idx = elem.z; // Thread index in DFT int m = grid.z; // Threads per DFT int tg_idx = elem.y * n; // Index of this DFT in threadgroup threadgroup float2* buf = &shared_in[tg_idx]; perform_fft(fft_idx, &p, m, n, buf); read_writer.write(); } template [[kernel]] void rader_fft( const device in_T* in [[buffer(0)]], device out_T* out [[buffer(1)]], const device float2* raders_b_q [[buffer(2)]], const device short* raders_g_q [[buffer(3)]], const device short* raders_g_minus_q [[buffer(4)]], constant const int& n, constant const int& batch_size, constant const int& rader_n, uint3 elem [[thread_position_in_grid]], uint3 grid [[threads_per_grid]]) { // Use Rader's algorithm to compute fast FFTs // when a prime factor `p` of `n` is greater than 13 but // has `p - 1` Stockham decomposable into to prime factors <= 13. // // E.g. n = 102 // = 2 * 3 * 17 // . = 2 * 3 * RADER(16) // . = 2 * 3 * RADER(4 * 4) // // In numpy: // x_perm = x[g_q] // y = np.fft.fft(x_perm) * b_q // z = np.fft.ifft(y) + x[0] // out = z[g_minus_q] // out[0] = x[1:].sum() // // Where the g_q and g_minus_q are permutations formed // by the group under multiplicative modulo N using the // primitive root of N and b_q is a constant. // See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm // // Rader's uses fewer operations than Bluestein's and so // is more accurate. It's also faster in most cases.
threadgroup float2 shared_in[tg_mem_size]; thread ReadWriter read_writer = ReadWriter( in, &shared_in[0], out, n, batch_size, elems_per_thread_, elem, grid, inv_); if (read_writer.out_of_bounds()) { return; }; read_writer.load(); threadgroup_barrier(mem_flags::mem_threadgroup); // The number of the threads we're using for each DFT int m = grid.z; int fft_idx = elem.z; int tg_idx = elem.y * n; threadgroup float2* buf = &shared_in[tg_idx]; // rader_m = n / rader_n; int rader_m = rader_m_; // We have to load two x_0s for each thread since sometimes // elems_per_thread_ crosses a boundary. // E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4 // 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 short x_0_index = metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1); float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]}; // Do the Rader permutation in shared memory float2 temp[MAX_RADIX]; int max_index = n - rader_m - 1; for (int e = 0; e < elems_per_thread_; e++) { short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); short g_q = raders_g_q[index / rader_m]; temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m]; } threadgroup_barrier(mem_flags::mem_threadgroup); for (int e = 0; e < elems_per_thread_; e++) { short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); buf[index + rader_m] = temp[e]; } threadgroup_barrier(mem_flags::mem_threadgroup); // Rader FFT on x[rader_m:] int p = 1; perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); // x_1 + ... + x_n is computed for us in the first FFT step so // we save it in the first rader_m indices of the array for later.
int x_sum_index = metal::min(fft_idx, rader_m - 1); buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)]; float2 inv = {1.0f, -1.0f}; for (int e = 0; e < elems_per_thread_; e++) { short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); short interleaved_index = index / rader_m + (index % rader_m) * (rader_n - 1); temp[e] = complex_mul( buf[rader_m + interleaved_index], raders_b_q[interleaved_index % (rader_n - 1)]); } threadgroup_barrier(mem_flags::mem_threadgroup); for (int e = 0; e < elems_per_thread_; e++) { short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); buf[rader_m + index] = temp[e] * inv; } threadgroup_barrier(mem_flags::mem_threadgroup); // Rader IFFT on x[rader_m:] p = 1; perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)}; for (int e = 0; e < elems_per_thread_; e++) { short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1); short diff_index = index / (rader_n - 1) - x_0_index; temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index]; } // Use the sum of elements that was computed in the first FFT float2 x_sum = buf[x_0_index] + x_0[0]; threadgroup_barrier(mem_flags::mem_threadgroup); for (int e = 0; e < elems_per_thread_; e++) { short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); short g_q_index = index % (rader_n - 1); short g_q = raders_g_minus_q[g_q_index]; short out_index = index - g_q_index + g_q + (index / (rader_n - 1)); buf[out_index] = temp[e]; } buf[x_0_index * rader_n] = x_sum; threadgroup_barrier(mem_flags::mem_threadgroup); p = rader_n; perform_fft(fft_idx, &p, m, n, buf); read_writer.write(); } template [[kernel]] void bluestein_fft( const device in_T* in [[buffer(0)]], device out_T* out [[buffer(1)]], const device float2* w_q [[buffer(2)]], const device float2* w_k [[buffer(3)]], constant const int& length, constant const int& n, constant const int& batch_size, uint3 elem
[[thread_position_in_grid]], uint3 grid [[threads_per_grid]]) { // Computes arbitrary length FFTs with Bluestein's algorithm // // In numpy: // bluestein_n = next_power_of_2(2*n - 1) // out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q) // // Where w_k and w_q are precomputed on CPU in high precision as: // w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2)) // w_q = np.fft.fft(1/w_k[-n:]) threadgroup float2 shared_in[tg_mem_size]; thread ReadWriter read_writer = ReadWriter( in, &shared_in[0], out, n, batch_size, elems_per_thread_, elem, grid, inv_); if (read_writer.out_of_bounds()) { return; }; read_writer.load_padded(length, w_k); threadgroup_barrier(mem_flags::mem_threadgroup); int p = 1; int fft_idx = elem.z; // Thread index in DFT int m = grid.z; // Threads per DFT int tg_idx = elem.y * n; // Index of this DFT in threadgroup threadgroup float2* buf = &shared_in[tg_idx]; // fft perform_fft(fft_idx, &p, m, n, buf); float2 inv = float2(1.0f, -1.0f); for (int t = 0; t < elems_per_thread_; t++) { int index = fft_idx + t * m; buf[index] = complex_mul(buf[index], w_q[index]) * inv; } threadgroup_barrier(mem_flags::mem_threadgroup); // ifft p = 1; perform_fft(fft_idx, &p, m, n, buf); read_writer.write_padded(length, w_k); } template < int tg_mem_size, typename in_T, typename out_T, int step, bool real = false> [[kernel]] void four_step_fft( const device in_T* in [[buffer(0)]], device out_T* out [[buffer(1)]], constant const int& n1, constant const int& n2, constant const int& batch_size, uint3 elem [[thread_position_in_grid]], uint3 grid [[threads_per_grid]]) { // Fast four step FFT implementation for powers of 2. int overall_n = n1 * n2; int n = step == 0 ? n1 : n2; int stride = step == 0 ?
n2 : n1; // The number of the threads we're using for each DFT int m = grid.z; int fft_idx = elem.z; threadgroup float2 shared_in[tg_mem_size]; threadgroup float2* buf = &shared_in[elem.y * n]; using read_writer_t = ReadWriter; read_writer_t read_writer = read_writer_t( in, &shared_in[0], out, n, batch_size, elems_per_thread_, elem, grid, inv_); if (read_writer.out_of_bounds()) { return; }; read_writer.load_strided(stride, overall_n); threadgroup_barrier(mem_flags::mem_threadgroup); int p = 1; perform_fft(fft_idx, &p, m, n, buf); read_writer.write_strided(stride, overall_n); } ================================================ FILE: mlx/backend/metal/kernels/fft.metal ================================================ // Copyright © 2024 Apple Inc. #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/fft.h" #define instantiate_fft(tg_mem_size, in_T, out_T) \ instantiate_kernel( \ "fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \ fft, \ tg_mem_size, \ in_T, \ out_T) #define instantiate_rader(tg_mem_size, in_T, out_T) \ instantiate_kernel( \ "rader_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \ rader_fft, \ tg_mem_size, \ in_T, \ out_T) #define instantiate_bluestein(tg_mem_size, in_T, out_T) \ instantiate_kernel( \ "bluestein_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \ bluestein_fft, \ tg_mem_size, \ in_T, \ out_T) #define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \ instantiate_kernel( \ "four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T "_" #step "_" #real, \ four_step_fft, \ tg_mem_size, \ in_T, \ out_T, \ step, \ real) // clang-format off #define instantiate_ffts(tg_mem_size) \ instantiate_fft(tg_mem_size, float2, float2) \ instantiate_fft(tg_mem_size, float, float2) \ instantiate_fft(tg_mem_size, float2, float) \ instantiate_rader(tg_mem_size, float2, float2) \ instantiate_rader(tg_mem_size, float, float2) \ instantiate_rader(tg_mem_size, float2, float) \ instantiate_bluestein(tg_mem_size, float2, float2) \
instantiate_bluestein(tg_mem_size, float, float2) \ instantiate_bluestein(tg_mem_size, float2, float) \ instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/false) \ instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/false) \ instantiate_four_step(tg_mem_size, float, float2, 0, /*real=*/true) \ instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/true) \ instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/true) \ instantiate_four_step(tg_mem_size, float2, float, 1, /*real=*/true) // It's substantially faster to statically define the // threadgroup memory size rather than using // `setThreadgroupMemoryLength` on the compute encoder. // For non-power of 2 sizes we round up the shared memory. instantiate_ffts(256) instantiate_ffts(512) instantiate_ffts(1024) instantiate_ffts(2048) // 4096 is the max that will fit into 32KB of threadgroup memory. instantiate_ffts(4096) // clang-format on ================================================ FILE: mlx/backend/metal/kernels/fp4.h ================================================ #pragma once struct fp4_e2m1 { fp4_e2m1(float x) { if (metal::isnan(x)) { bits = 0x7; return; } const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0; x = metal::abs(x); if (x > 5.0f) { bits = 0x7; } else if (x >= 3.5f) { bits = 0x6; } else if (x > 2.5f) { bits = 0x5; } else if (x >= 1.75f) { bits = 0x4; } else if (x > 1.25f) { bits = 0x3; } else if (x >= 0.75f) { bits = 0x2; } else if (x > 0.25f) { bits = 0x1; } else { bits = 0x0; } bits |= sign_bit; } operator float16_t() { half converted = as_type(ushort((bits & 7) << 9)); converted *= 16384.0; return bits & 8 ?
-converted : converted; } operator float() { return static_cast(this->operator float16_t()); } operator bfloat16_t() { return static_cast(this->operator float16_t()); } uint8_t bits; }; ================================================ FILE: mlx/backend/metal/kernels/fp8.h ================================================ #pragma once struct fp8_e4m3 { template fp8_e4m3(T f) { // From PyTorch // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148 uint32_t fp8_max = 543 << 21; uint32_t denorm_mask = 141 << 23; uint32_t f_bits = as_type(static_cast(f)); uint32_t sign = f_bits & 0x80000000; f_bits ^= sign; if (f_bits >= fp8_max) { // Default behavior saturates to min/max bits = 0x7E; } else { if (f_bits < (121 << 23)) { f_bits = as_type( as_type(f_bits) + as_type(denorm_mask)); bits = static_cast(f_bits - denorm_mask); } else { // resulting mantissa is odd uint8_t mant_odd = (f_bits >> 20) & 1; f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; f_bits += mant_odd; bits = static_cast(f_bits >> 20); } } bits |= static_cast(sign >> 24); } operator float16_t() { uint16_t v = (bits & 127) << 7; half converted = as_type(v); converted *= 256.0; auto sign = bits & 128; return (sign ? -converted : converted); } operator bfloat16_t() { return static_cast(this->operator float16_t()); } operator float() { return static_cast(this->operator float16_t()); } uint8_t bits; }; struct fp8_e8m0 { fp8_e8m0(float x) { if (!metal::isfinite(x)) { bits = 0xFF; return; } if (x < 0.0f) { bits = 0x00; return; } float le = metal::log2(x); int n = int(metal::round(le)); n = n < -127 ? -127 : n; n = n > 127 ? 127 : n; bits = static_cast(n + 127); } operator bfloat16_t() { uint16_t out = (bits == 0 ? 0x40 : (static_cast(bits) << 7)); return as_type(out); } operator float() { uint32_t out = (bits == 0 ?
0x400000 : (static_cast(bits) << 23)); return as_type(out); } uint8_t bits; }; ================================================ FILE: mlx/backend/metal/kernels/fp_quantized.h ================================================ // Copyright © 2025 Apple Inc. #include #include #include "mlx/backend/metal/kernels/fp4.h" #include "mlx/backend/metal/kernels/fp8.h" constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; constant bool align_K [[function_constant(202)]]; using namespace metal; #define MLX_MTL_CONST static constant constexpr const MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int QUAD_SIZE = 4; template inline constexpr short get_pack_factor() { return wsize / bits; } template inline constexpr short get_bytes_per_pack() { return wsize / 8; } template static inline T dequantize_scale(uint8_t s) { if constexpr (group_size == 16) { // Use nv scale return T(*(thread fp8_e4m3*)(&s)); } else { return T(*(thread fp8_e8m0*)(&s)); } } template struct Quantize { uint8_t operator()(float x) { if (bits == 8) { return fp8_e4m3(x).bits; } else { return fp4_e2m1(x).bits; } } }; template struct Dequantize { U operator()(uint8_t x) { if constexpr (bits == 8) { return U(*(thread fp8_e4m3*)(&x)); } else { return U(*(thread fp4_e2m1*)(&x)); } } }; template inline void load_vector(const device T* x, thread U* x_thread) { #pragma unroll for (int i = 0; i < values_per_thread; i++) { x_thread[i] = x[i]; } } template inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { for (int i = 0; i < N; i++) { x_thread[i] = x[i]; } for (int i = N; i < values_per_thread; i++) { x_thread[i] = 0; } } template inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) { U accum = 0; if constexpr (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * Dequantize<4>{}(ws[i]) + x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) + x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12)); } } else { for (int i = 0; i < values_per_thread; i++) { accum += x_thread[i] * Dequantize<8>{}(w[i]); } } return scale * accum; } template inline U qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) { U accum = 0; if constexpr (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (N / 4); i++) { accum += (x_thread[4 * i] * Dequantize<4>{}(ws[i]) + x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) + x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) + x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12)); } } else { for (int i = 0; i < N; i++) { accum += x_thread[i] * Dequantize<8>{}(w[i]); } } return scale * accum; } template inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) { if constexpr (bits == 4) { for (int i = 0; i < (values_per_thread / 2); i++) { result[2 * i] += x * scale * Dequantize<4>{}(w[i]); result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4); } } else { for (int i = 0; i < values_per_thread; i++) { result[i] += x * scale * Dequantize<8>{}(w[i]); } } } template inline void dequantize(uint8_t w, U scale, threadgroup U* w_local) { if constexpr (bits == 4) { w_local[0] = scale * Dequantize<4, U>{}(w); w_local[1] = scale * Dequantize<4, U>{}(w >> 4); } else { w_local[0] = scale * Dequantize<8, U>{}(w); } } template < typename T, short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits> struct QuantizedBlockLoader { MLX_MTL_CONST short pack_factor = get_pack_factor<8, bits>(); MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; MLX_MTL_CONST short group_steps = group_size < BCOLS ?
1 : group_size / BCOLS; MLX_MTL_CONST short scale_step = group_size < BCOLS ? BCOLS / group_size : 1; static_assert( (n_reads * pack_factor) <= group_size, "The number of reads per thread must be less than the group size."); const int src_ld; const int tile_stride; short group_step_cnt; const int group_stride; const short thread_idx; const short bi; const short bj; threadgroup T* dst; const device uint8_t* src; const device uint8_t* scales; QuantizedBlockLoader( const device uint8_t* src_, const device uint8_t* scales_, const int src_ld_, threadgroup T* dst_, ushort simd_group_id [[simdgroup_index_in_threadgroup]], ushort simd_lane_id [[thread_index_in_simdgroup]]) : src_ld(src_ld_), tile_stride( reduction_dim ? BCOLS_PACKED * bytes_per_pack : BROWS * src_ld * bytes_per_pack / pack_factor), group_step_cnt(0), group_stride(BROWS * src_ld / group_size), thread_idx(simd_group_id * 32 + simd_lane_id), bi(n_reads * thread_idx / BCOLS_PACKED), bj((n_reads * thread_idx) % BCOLS_PACKED), dst(dst_ + bi * dst_ld + bj * pack_factor), src(src_ + bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack), scales( scales_ + bi * src_ld / group_size + (bj * pack_factor) / group_size) {} void load_unsafe() const { if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { return; } T scale = dequantize_scale(*scales); for (int i = 0; i < n_reads; i++) { dequantize( src[i * bytes_per_pack], scale, dst + i * pack_factor); } } void load_safe(short2 src_tile_dim) const { if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { return; } if (reduction_dim == 1 && bi >= src_tile_dim.x) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } return; } if (reduction_dim == 0 && bi >= src_tile_dim.y) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } return; } T scale = dequantize_scale(*scales); for (int i = 0; i < n_reads; i++) { dequantize( src[i * bytes_per_pack], scale, dst + i * pack_factor); } } void next() { src += tile_stride; if (reduction_dim == 1)
{ if (group_steps > 1) { group_step_cnt++; if (group_step_cnt == group_steps) { group_step_cnt = 0; scales++; } } else { scales += scale_step; } } else { scales += group_stride; } } }; template METAL_FUNC void fp_qmv_quad_impl( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, constant int& in_vec_size, const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint quad_gid [[quadgroup_index_in_threadgroup]], uint quad_lid [[thread_index_in_quadgroup]]) { constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; constexpr int pack_factor = get_pack_factor<32, bits>(); constexpr int values_per_thread = D / QUAD_SIZE; constexpr int steps_per_thread = values_per_thread < group_size ? 1 : values_per_thread / group_size; constexpr int values_per_step = values_per_thread / steps_per_thread; constexpr int packs_per_thread = values_per_thread / pack_factor; constexpr int packs_per_step = values_per_step / pack_factor; constexpr int results_per_quadgroup = 8; typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_quadgroup] = {0}; // Adjust positions const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; w += out_row * in_vec_size_w + quad_lid * packs_per_thread; scales += out_row * in_vec_size_g + (quad_lid * values_per_thread) / group_size; x += tid.x * in_vec_size + quad_lid * values_per_thread; y += tid.x * out_vec_size + out_row; load_vector(x, x_thread); for (int row = 0; row < results_per_quadgroup; row++) { auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd; #pragma unroll for (int k = 0; k < steps_per_thread; ++k) { U s = dequantize_scale(sl[0]); if (row * quads_per_simd + out_row < out_vec_size) { result[row] += qdot( wl, x_thread + k * values_per_step, s); }
sl++; wl += (sizeof(uint32_t) / sizeof(uint8_t)) * packs_per_step; } } for (int row = 0; row < results_per_quadgroup; row++) { result[row] = quad_sum(result[row]); if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { y[row * quads_per_simd] = static_cast(result[row]); } } } template METAL_FUNC void fp_qmv_fast_impl( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int packs_per_thread = 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int pack_factor = get_pack_factor<32, bits>(); constexpr int bytes_per_pack = get_bytes_per_pack<32>(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; const device uint8_t* ws = (const device uint8_t*)w; typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.x * in_vec_size + simd_lid * values_per_thread; y += tid.x * out_vec_size + out_row; for (int k = 0; k < in_vec_size; k += block_size) { load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device auto* sl = scales + row * in_vec_size_g; U s = dequantize_scale(sl[0]); result[row] += qdot(wl,
x_thread, s); } ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; x += block_size; } for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { y[row] = static_cast(result[row]); } } } /* fp_qmv_impl: general quantized matrix-vector product y = W x for packed fp4/fp8 weights with one fp8 scale byte per group of group_size inputs. Output rows are processed in tiles of num_simdgroups * results_per_simdgroup; when the matrix is smaller than one tile every row is guarded individually, otherwise the last tile is shifted back (used_out_row) so it stays inside the matrix. Scale bytes must always be decoded with dequantize_scale before being used as multipliers. */ template METAL_FUNC void fp_qmv_impl( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; constexpr int pack_factor = get_pack_factor<32, bits>(); constexpr int bytes_per_pack = get_bytes_per_pack<32>(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; const device uint8_t* ws = (const device uint8_t*)w; typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); if (out_row >= out_vec_size) { return; } // In this case we need to properly guard all our reads because there isn't // even 1 tile in the matrix if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.x * in_vec_size + simd_lid * values_per_thread; y += tid.x * out_vec_size + out_row; int k = 0; for (; k < in_vec_size -
block_size; k += block_size) { load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup && out_row + row < out_vec_size; row++) { auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device auto* sl = scales + row * in_vec_size_g; /* Decode the fp8 scale byte: its raw bit pattern is not a usable multiplier. */ U s = dequantize_scale(sl[0]); result[row] += qdot(wl, x_thread, s); } ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; x += block_size; } const int remaining = clamp( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); if (remaining > 0) { load_vector_safe(x, x_thread, remaining); for (int row = 0; row < results_per_simdgroup && out_row + row < out_vec_size; row++) { auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device auto* sl = scales + row * in_vec_size_g; U s = dequantize_scale(sl[0]); /* Only `remaining` inputs are valid here, so bound the weight reads as well (matches the tiled path below). */ result[row] += qdot_safe(wl, x_thread, s, remaining); } } for (int row = 0; row < results_per_simdgroup && out_row + row < out_vec_size; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { y[row] = static_cast(result[row]); } } } // In this case the last tile is moved back to redo some output values else { ws += used_out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.x * in_vec_size + simd_lid * values_per_thread; y += tid.x * out_vec_size + used_out_row; int k = 0; for (; k < in_vec_size - block_size; k += block_size) { load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device auto* sl = scales + row * in_vec_size_g; U s = dequantize_scale(sl[0]); result[row] += qdot(wl, x_thread, s); } ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; x += block_size; } const int remaining = clamp( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); if (remaining > 0) { load_vector_safe(x, x_thread,
remaining); for (int row = 0; row < results_per_simdgroup; row++) { auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device auto* sl = scales + row * in_vec_size_g; U s = dequantize_scale(sl[0]); result[row] += qdot_safe(wl, x_thread, s, remaining); } } for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { y[row] = static_cast(result[row]); } } } } template METAL_FUNC void fp_qvm_impl( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, const int in_vec_size, const int out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int num_simdgroups = 2; constexpr int pack_factor = get_pack_factor<32, bits>(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int tn = group_size / pack_factor; constexpr int block_size = SIMD_SIZE; using W_T = uint32_t; const device W_T* ws = (const device W_T*)w; typedef float U; typedef struct { W_T wi[tn * bytes_per_pack]; } vec_w; thread vec_w w_local; thread U result[tn * pack_factor] = {0}; thread U scale = 0; thread U x_local = 0; // Adjust positions const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; const int out_vec_size_g = out_vec_size / group_size; // 32 * (tid.y * 2 + simd_gid) int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; scales += out_col / group_size + simd_lid * out_vec_size_g; x += tid.x * in_vec_size + simd_lid; y += tid.x * out_vec_size + out_col; if (out_col >= out_vec_size) { return; } // Loop over in_vec in blocks of block_size int remaining = in_vec_size % block_size; if (remaining == 0) { for (int i = 0; i < in_vec_size; i += block_size) { x_local = *x; scale = dequantize_scale(*scales); w_local = *((device vec_w*)ws); qouter( (thread uint8_t*)&w_local, x_local,
scale, result); x += block_size; scales += block_size * out_vec_size_g; ws += block_size * out_vec_size_w; } } else { for (int i = block_size; i < in_vec_size; i += block_size) { x_local = *x; scale = dequantize_scale(*scales); w_local = *((device vec_w*)ws); qouter( (thread uint8_t*)&w_local, x_local, scale, result); x += block_size; scales += block_size * out_vec_size_g; ws += block_size * out_vec_size_w; } if (static_cast(simd_lid) < remaining) { x_local = *x; scale = dequantize_scale(*scales); w_local = *((device vec_w*)ws); } else { x_local = 0; scale = 0; } qouter( (thread uint8_t*)&w_local, x_local, scale, result); } // Accumulate in the simdgroup #pragma clang loop unroll(full) for (int k = 0; k < tn * pack_factor; k++) { result[k] = simd_sum(result[k]); } // Store the result if (simd_lid == 0) { #pragma clang loop unroll(full) for (int k = 0; k < tn * pack_factor; k++) { y[k] = static_cast(result[k]); } } } template < typename T, const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> METAL_FUNC void fp_qmm_t_impl( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, threadgroup T* Xs, threadgroup T* Ws, const constant int& K, const constant int& N, const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); (void)lid; constexpr int WM = 2; constexpr int WN = 2; constexpr int pack_factor = get_pack_factor<8, bits>(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: BlockMMA; using loader_x_t = mlx::steel::BlockLoader; using loader_w_t =
QuantizedBlockLoader< T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>; // Set the block const int K_w = K * bytes_per_pack / pack_factor; const int K_g = K / group_size; const int y_row = tid.y * BM; const int y_col = tid.x * BN; auto wl = (const device uint8_t*)w; x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); const short num_outs = min(BN, N - y_col); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); if (num_els < BM) { if (!aligned_N && num_outs < BN) { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(BK, num_els)); loader_w.load_safe(short2(BK, num_outs)); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(BK, num_els)); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } } else { if (!aligned_N && num_outs < BN) { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_unsafe(); loader_w.load_safe(short2(BK, num_outs)); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } } // Store results to device memory threadgroup_barrier(mem_flags::mem_threadgroup); if (num_els < BM || num_outs < BN) { mma_op.store_result_safe(y, N, short2(num_outs, num_els)); } else { 
mma_op.store_result(y, N); } } template < typename T, const int group_size, const int bits, const int BM = 32, const int BK = 32, const int BN = 32> METAL_FUNC void fp_qmm_n_impl( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, threadgroup T* Xs, threadgroup T* Ws, const constant int& K, const constant int& N, const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); (void)lid; constexpr int WM = 2; constexpr int WN = 2; constexpr int pack_factor = get_pack_factor<8, bits>(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: BlockMMA; using loader_x_t = mlx::steel:: BlockLoader; using loader_w_t = QuantizedBlockLoader< T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>; auto wl = (const device uint8_t*)w; // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); loader_w_t loader_w(wl, scales, N, Ws, simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); if (num_els < BM) { if ((K % BK) != 0) { const int k_blocks = K / BK; for (int k = 0; k < k_blocks; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(BK, num_els)); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); 
loader_w.next(); } const short num_k = K - k_blocks * BK; threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(num_k, num_els)); loader_w.load_safe(short2(BN, num_k)); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(BK, num_els)); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } } else { if ((K % BK) != 0) { const int k_blocks = K / BK; for (int k = 0; k < k_blocks; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } const short num_k = K - k_blocks * BK; threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(num_k, BM)); loader_w.load_safe(short2(BN, num_k)); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } } // Store results to device memory threadgroup_barrier(mem_flags::mem_threadgroup); if (num_els < BM) { mma_op.store_result_safe(y, N, short2(BN, num_els)); } else { mma_op.store_result(y, N); } } template METAL_FUNC void adjust_matrix_offsets( const device T*& x, const device uint32_t*& w, const device uint8_t*& scales, device T*& y, int output_stride, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, uint3 tid [[threadgroup_position_in_grid]]) { // Set the input/output matrices uint32_t x_idx = tid.z; 
uint32_t w_idx = tid.z; if (x_batch_ndims == 1) { x += x_idx * x_strides[0]; } else { x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); } if (w_batch_ndims == 1) { w += w_idx * w_strides[0]; scales += w_idx * s_strides[0]; } else { ulong2 idx = elem_to_loc_broadcast( w_idx, w_shape, w_strides, s_strides, w_batch_ndims); w += idx.x; scales += idx.y; } y += tid.z * output_stride; } template METAL_FUNC void adjust_matrix_offsets( const device T*& x, const device uint32_t*& w, const device uint8_t*& scales, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, device T*& y, int output_stride, const constant int& batch_ndims, const constant int* batch_shape, const constant int64_t* lhs_strides, const constant int64_t* rhs_strides, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, uint3 tid [[threadgroup_position_in_grid]]) { // Set the input/output matrices uint32_t x_idx; uint32_t w_idx; if (batch_ndims == 1) { x_idx = lhs_indices[tid.z * lhs_strides[0]]; w_idx = rhs_indices[tid.z * rhs_strides[0]]; } else { ulong2 idx = elem_to_loc_broadcast( tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); x_idx = lhs_indices[idx.x]; w_idx = rhs_indices[idx.y]; } if (x_batch_ndims == 1) { x += x_idx * x_strides[0]; } else { x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); } if (w_batch_ndims == 1) { w += w_idx * w_strides[0]; scales += w_idx * s_strides[0]; } else { ulong2 idx = elem_to_loc_broadcast( w_idx, w_shape, w_strides, s_strides, w_batch_ndims); w += idx.x; scales += idx.y; } y += tid.z * output_stride; } template [[kernel]] void fp_qmv_quad( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, const constant int& x_batch_ndims, const constant 
int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, uint3 tid [[threadgroup_position_in_grid]], uint quad_gid [[quadgroup_index_in_threadgroup]], uint quad_lid [[thread_index_in_quadgroup]]) { if (batched) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, y, out_vec_size * M, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); } fp_qmv_quad_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid); } template [[kernel]] void fp_qmv_fast( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { if (batched) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, y, out_vec_size * M, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); } fp_qmv_fast_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void fp_qmv( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid 
[[thread_index_in_simdgroup]]) { if (batched) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, y, out_vec_size * M, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); } fp_qmv_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void fp_qvm( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { if (batched) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, y, out_vec_size * M, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); } fp_qvm_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void fp_qvm_split_k( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, const constant int& final_block_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, y, out_vec_size * M, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); // When (in_vec_size % split_k != 0) the final block needs to 
be smaller int in_vec_size_adj = tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size; fp_qvm_impl( w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const bool aligned_N, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void fp_qmm_t( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, const constant int& K, const constant int& N, const constant int& M, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; if (batched) { adjust_matrix_offsets( x, w, scales, y, M * N, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); } fp_qmm_t_impl( w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void fp_qmm_n( const device uint32_t* w, const device uint8_t* scales, const device T* x, device T* y, const constant int& K, const constant int& N, const constant int& M, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], 
uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; if (batched) { adjust_matrix_offsets( x, w, scales, y, M * N, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); } fp_qmm_n_impl( w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template [[kernel]] void fp_gather_qmv_fast( const device uint32_t* w, const device uint8_t* scales, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, const constant int& batch_ndims, const constant int* batch_shape, const constant int64_t* lhs_strides, const constant int64_t* rhs_strides, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, lhs_indices, rhs_indices, y, out_vec_size * M, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); fp_qmv_fast_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void fp_gather_qmv( const device uint32_t* w, const device uint8_t* scales, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const 
constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, const constant int& batch_ndims, const constant int* batch_shape, const constant int64_t* lhs_strides, const constant int64_t* rhs_strides, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, lhs_indices, rhs_indices, y, out_vec_size * M, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); fp_qmv_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void fp_gather_qvm( const device uint32_t* w, const device uint8_t* scales, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, const constant int& batch_ndims, const constant int* batch_shape, const constant int64_t* lhs_strides, const constant int64_t* rhs_strides, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, lhs_indices, rhs_indices, y, out_vec_size * M, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); fp_qvm_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void 
fp_gather_qmm_t( const device uint32_t* w, const device uint8_t* scales, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, device T* y, const constant int& K, const constant int& N, const constant int& M, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, const constant int& batch_ndims, const constant int* batch_shape, const constant int64_t* lhs_strides, const constant int64_t* rhs_strides, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; adjust_matrix_offsets( x, w, scales, lhs_indices, rhs_indices, y, M * N, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); fp_qmm_t_impl( w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void fp_gather_qmm_n( const device uint32_t* w, const device uint8_t* scales, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, device T* y, const constant int& K, const constant int& N, const constant int& M, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, const constant int& batch_ndims, const constant int* batch_shape, const constant int64_t* lhs_strides, const constant int64_t* rhs_strides, uint3 tid 
[[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; adjust_matrix_offsets( x, w, scales, lhs_indices, rhs_indices, y, M * N, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, tid); fp_qmm_n_impl( w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, int group_size, int bits, int BM, int BN, int BK, int WM, int WN, bool transpose> [[kernel]] void fp_gather_qmm_rhs( const device T* x, const device uint32_t* w, const device uint8_t* scales, const device uint32_t* indices, device T* y, const constant int& M, const constant int& N, const constant int& K, uint3 tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { constexpr int pack_factor = get_pack_factor<8, bits>(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); using mma_t = mlx::steel::BlockMMA< T, T, BM, BN, BK, WM, WN, false, transpose, BK_padded, transpose ? BK_padded : BN_padded>; using loader_x_t = mlx::steel::BlockLoader; using loader_w_t = QuantizedBlockLoader< T, transpose ? BN : BK, transpose ? BK : BN, transpose ? BK_padded : BN_padded, transpose, WM * WN * SIMD_SIZE, group_size, bits>; threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[transpose ? 
BN * BK_padded : BK * BN_padded]; // Compute the block const int K_w = K * bytes_per_pack / pack_factor; const int K_g = K / group_size; const int N_w = N * bytes_per_pack / pack_factor; const int N_g = N / group_size; const int K_it = K / BK; const size_t stride_w = transpose ? N * K_w : K * N_w; const size_t stride_s = transpose ? N * K_g : K * N_g; const int y_row = tid.y * BM; const int y_col = tid.x * BN; const size_t y_row_long = size_t(y_row); const size_t y_col_long = size_t(y_col); // Prepare threadgroup bounds const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); // Calculate the final tiles in the case that K is not aligned const int k_remain = K - K_it * BK; const short2 tile_x = short2(k_remain, tgp_bm); const short2 tile_w = transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); // Move x and output to the correct block auto wl = (const device uint8_t*)w; x += y_row_long * K; y += y_row_long * N + y_col_long; wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; scales += transpose ? y_col_long * K_g : y_col / group_size; // Do as many matmuls as necessary uint32_t index; short offset; uint32_t index_next = indices[y_row]; short offset_next = 0; int n = 0; while (n < tgp_bm) { n++; offset = offset_next; index = index_next; offset_next = tgp_bm; for (; n < tgp_bm; n++) { if (indices[y_row + n] != index) { offset_next = n; index_next = indices[y_row + n]; break; } } threadgroup_barrier(mem_flags::mem_none); // Prepare threadgroup mma operation thread mma_t mma_op(simd_group_id, simd_lane_id); // Prepare threadgroup loading operations thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); thread loader_w_t loader_w( wl + index * stride_w, scales + index * stride_s, transpose ? 
K : N, Ws, simd_group_id, simd_lane_id); // Matrices are all aligned check nothing if (align_M && align_N) { gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); if (!align_K) { threadgroup_barrier(mem_flags::mem_threadgroup); gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); } // Store results to device memory if (offset_next - offset == BM) { mma_op.store_result(y, N); } else { mma_op.store_result_slice( y, N, short2(0, offset), short2(BN, offset_next)); } } else { // Tile aligned so check outside of the hot loop if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); if (!align_K) { threadgroup_barrier(mem_flags::mem_threadgroup); gemm_loop_finalize( Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); } // Store results to device memory if (offset_next - offset == BM) { mma_op.store_result(y, N); } else { mma_op.store_result_slice( y, N, short2(0, offset), short2(BN, offset_next)); } } // Tile partially aligned check rows else if (align_N || tgp_bn == BN) { gemm_loop_unaligned( Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); if (!align_K) { threadgroup_barrier(mem_flags::mem_threadgroup); gemm_loop_finalize( Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); } mma_op.store_result_slice( y, N, short2(0, offset), short2(BN, offset_next)); } // Tile partially aligned check cols else if (align_M || tgp_bm == BM) { gemm_loop_unaligned( Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); if (!align_K) { threadgroup_barrier(mem_flags::mem_threadgroup); gemm_loop_finalize( Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); } mma_op.store_result_slice( y, N, short2(0, offset), short2(tgp_bn, offset_next)); } // Nothing aligned so check both rows and cols else { gemm_loop_unaligned( Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); if (!align_K) { threadgroup_barrier(mem_flags::mem_threadgroup); gemm_loop_finalize( Xs, Ws, mma_op, 
loader_x, loader_w, tile_x, tile_w); } mma_op.store_result_slice( y, N, short2(0, offset), short2(tgp_bn, offset_next)); } } } } template [[kernel]] void fp_quantize( const device T* w [[buffer(0)]], device uint8_t* out [[buffer(1)]], device uint8_t* scales [[buffer(2)]], uint2 tidx [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { constexpr bool use_mx_scale = group_size == 32; size_t index = tidx.x + grid_dim.x * size_t(tidx.y); float scale; float w_thread = w[index]; if (use_mx_scale) { scale = simd_max(abs(w_thread)); } else { float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0); float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0); scale = tidx.x < 16 ? w_max_l : w_max_r; } scale /= bits == 4 ? 6.0f : 448.0f; using ScaleType = metal::conditional_t; auto s = ScaleType(scale); uint8_t q_scale = s.bits; scale = float(s); size_t gindex = index / group_size; if (index % group_size == 0) { scales[gindex] = q_scale; } uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); if (bits == 4) { uint8_t sval = simd_shuffle_down(output, 1); output |= sval << bits; } constexpr int pack_factor = bits == 8 ? 1 : 2; if (index % pack_factor == 0) { out[index / pack_factor] = output; } } template [[kernel]] void fp_dequantize( const device uint8_t* w [[buffer(0)]], const device uint8_t* scales [[buffer(1)]], device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { constexpr bool use_mx_scale = group_size == 32; constexpr int pack_factor = bits == 8 ? 
1 : 2; size_t offset = index.x + grid_dim.x * size_t(index.y); size_t oindex = offset * pack_factor; size_t gindex = oindex / group_size; out += oindex; using ScaleType = metal::conditional_t; auto q_scale = ((device ScaleType*)(scales))[gindex]; auto scale = float(q_scale); uint val = w[offset]; #pragma clang loop unroll(full) for (int i = 0; i < pack_factor; i++) { uint8_t d; if (bits == 4) { d = (val >> (bits * i)) & 0x0f; } else if (bits == 8) { d = val; } out[i] = static_cast(scale * Dequantize{}(d)); } } template [[kernel]] void fp_quantize_dequantize( const device T* w [[buffer(0)]], device T* out [[buffer(1)]], uint2 tidx [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { constexpr bool use_mx_scale = group_size == 32; size_t index = tidx.x + grid_dim.x * size_t(tidx.y); float scale; float w_thread = w[index]; if (use_mx_scale) { scale = simd_max(abs(w_thread)); } else { float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0); float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0); scale = tidx.x < 16 ? w_max_l : w_max_r; } scale /= bits == 4 ? 6.0f : 448.0f; using ScaleType = metal::conditional_t; auto s = ScaleType(scale); scale = float(s); uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); out[index] = static_cast(scale * Dequantize{}(output)); } ================================================ FILE: mlx/backend/metal/kernels/fp_quantized.metal ================================================ // Copyright © 2025 Apple Inc. 
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp_quantized.h"

// Instantiation helpers for the fp-quantized kernels in fp_quantized.h.
// Each helper expands to instantiate_kernel(name, func, template args...),
// where `name` is built by stringizing the macro arguments:
//   "<mode>_<name>_<type>_gs_<group_size>_b_<bits>" + variant suffixes
// NOTE(review): these strings are presumably the lookup keys used by the
// host, so the suffix order must stay in sync with host-side name building.

// Base variant: template args (type, group_size, bits).
#define instantiate_quantized(mode, name, type, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits, \
    fp_ ## name, \
    type, \
    group_size, \
    bits)

// Adds the `batched` template arg; name suffix "_batch_<batched>".
#define instantiate_quantized_batched(mode, name, type, batched, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \
    fp_ ## name, \
    type, \
    group_size, \
    bits, \
    batched)

// Adds the `aligned` (N-aligned) template arg; name suffix "_alN_<aligned>".
#define instantiate_quantized_aligned(mode, name, type, aligned, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
    fp_ ## name, \
    type, \
    group_size, \
    bits, \
    aligned)

// Both `aligned` and `batched`; suffix "_alN_<aligned>_batch_<batched>".
#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned "_batch_" #batched, \
    fp_ ## name, \
    type, \
    group_size, \
    bits, \
    aligned, \
    batched)

// Adds `D` and `batched`; suffix "_d_<D>_batch_<batched>".
#define instantiate_quantized_quad(mode, name, type, D, batched, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D "_batch_" #batched, \
    fp_ ## name, \
    type, \
    group_size, \
    bits, \
    D, \
    batched)

// Adds `split_k`; suffix "_spk_<split_k>".
#define instantiate_quantized_split_k(mode, name, type, split_k, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
    fp_ ## name, \
    type, \
    group_size, \
    bits, \
    split_k)

// Takes the kernel function explicitly plus tile sizes and `transpose`.
// `transpose` is not stringized; callers encode it in `name` instead.
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose, mode, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
    func, \
    type, \
    group_size, \
    bits, \
    bm, \
    bn, \
    bk, \
    wm, \
    wn, \
    transpose)

// Both batched (1) and non-batched (0) variants of one kernel.
#define instantiate_quantized_batched_wrap(name, type, mode, group_size, bits) \
  instantiate_quantized_batched(mode, name, type, 1, group_size, bits) \
  instantiate_quantized_batched(mode, name, type, 0, group_size, bits)

// Matrix-vector / vector-matrix / qmm_n kernels, batched and not.
#define instantiate_quantized_all_batched(type, mode, group_size, bits) \
  instantiate_quantized_batched_wrap(qmv_fast, type, mode, group_size, bits) \
  instantiate_quantized_batched_wrap(qmv, type, mode, group_size, bits) \
  instantiate_quantized_batched_wrap(qvm, type, mode, group_size, bits) \
  instantiate_quantized_batched_wrap(qmm_n, type, mode, group_size, bits)

// Gather variants, which take no extra template args.
#define instantiate_quantized_all_single(type, mode, group_size, bits) \
  instantiate_quantized(mode, gather_qmv_fast, type, group_size, bits) \
  instantiate_quantized(mode, gather_qmv, type, group_size, bits) \
  instantiate_quantized(mode, gather_qvm, type, group_size, bits) \
  instantiate_quantized(mode, gather_qmm_n, type, group_size, bits)

// qmm_t kernels for aligned and unaligned N (and batched for qmm_t).
#define instantiate_quantized_all_aligned(type, mode, group_size, bits) \
  instantiate_quantized_aligned(mode, gather_qmm_t, type, true, group_size, bits) \
  instantiate_quantized_aligned(mode, gather_qmm_t, type, false, group_size, bits) \
  instantiate_quantized_aligned_batched(mode, qmm_t, type, true, 1, group_size, bits) \
  instantiate_quantized_aligned_batched(mode, qmm_t, type, true, 0, group_size, bits) \
  instantiate_quantized_aligned_batched(mode, qmm_t, type, false, 1, group_size, bits) \
  instantiate_quantized_aligned_batched(mode, qmm_t, type, false, 0, group_size, bits)

// qmv_quad for D in {64, 128}, batched and not.
#define instantiate_quantized_all_quad(type, mode, group_size, bits) \
  instantiate_quantized_quad(mode, qmv_quad, type, 64, 1, group_size, bits) \
  instantiate_quantized_quad(mode, qmv_quad, type, 64, 0, group_size, bits) \
  instantiate_quantized_quad(mode, qmv_quad, type, 128, 1, group_size, bits) \
  instantiate_quantized_quad(mode, qmv_quad, type, 128, 0, group_size, bits)

// qvm_split_k for split_k in {8, 32}.
#define instantiate_quantized_all_splitk(type, mode, group_size, bits) \
  instantiate_quantized_split_k(mode, qvm_split_k, type, 8, group_size, bits) \
  instantiate_quantized_split_k(mode, qvm_split_k, type, 32, group_size, bits)

// gather_qmm_rhs with 16x32x32 tiles and 1x2 simdgroups; "_nt" uses
// transpose = true, "_nn" uses transpose = false.
#define instantiate_quantized_all_rhs(type, mode, group_size, bits) \
  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true, mode, group_size, bits) \
  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false, mode, group_size, bits)

// Element-wise kernels: quantize, dequantize and the round trip.
#define instantiate_quantize_dequantize(type, mode, group_size, bits) \
  instantiate_kernel( \
    #mode "_quantize_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
    fp_quantize_dequantize, \
    type, \
    group_size, \
    bits) \
  instantiate_kernel( \
    #mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
    fp_quantize, \
    type, \
    group_size, \
    bits) \
  instantiate_kernel( \
    #mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
    fp_dequantize, \
    type, \
    group_size, \
    bits)

// Every kernel for one (type, mode, group_size, bits) combination.
#define instantiate_quantized_modes(type, mode, group_size, bits) \
  instantiate_quantized_all_batched(type, mode, group_size, bits) \
  instantiate_quantized_all_single(type, mode, group_size, bits) \
  instantiate_quantized_all_quad(type, mode, group_size, bits) \
  instantiate_quantized_all_splitk(type, mode, group_size, bits) \
  instantiate_quantized_all_aligned(type, mode, group_size, bits) \
  instantiate_quantized_all_rhs(type, mode, group_size, bits) \
  instantiate_quantize_dequantize(type, mode, group_size, bits)

// Supported modes: nvfp4 (gs 16, 4 bits), mxfp8 (gs 32, 8 bits),
// mxfp4 (gs 32, 4 bits).
#define instantiate_quantized_types(type) \
  instantiate_quantized_modes(type, nvfp4, 16, 4) \
  instantiate_quantized_modes(type, mxfp8, 32, 8) \
  instantiate_quantized_modes(type, mxfp4, 32, 4)

instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on

================================================
FILE: mlx/backend/metal/kernels/fp_quantized_nax.h
================================================
// Copyright © 2025 Apple Inc.
#include
#include

#include "mlx/backend/metal/kernels/fp4.h"
#include "mlx/backend/metal/kernels/fp8.h"

// Host-selected function constants used by fp_gather_qmm_rhs_nax to skip
// bounds handling when M / N / K are tile-aligned.
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;

template
inline constexpr short get_pack_factor() {
  return wsize / bits;
}

template
inline constexpr short get_bytes_per_pack() {
  return wsize / 8;
}

// Decode one 8-bit group scale: fp8 e4m3 for 16-element groups, fp8 e8m0
// otherwise.
template
static inline T dequantize_scale(uint8_t s) {
  if constexpr (group_size == 16) {
    // Use nv scale
    return T(*(thread fp8_e4m3*)(&s));
  } else {
    return T(*(thread fp8_e8m0*)(&s));
  }
}

// Encode a float as an fp8 e4m3 (bits == 8) or fp4 e2m1 code.
template
struct Quantize {
  uint8_t operator()(float x) {
    if (bits == 8) {
      return fp8_e4m3(x).bits;
    } else {
      return fp4_e2m1(x).bits;
    }
  }
};

// Decode an fp8 e4m3 (bits == 8) or fp4 e2m1 code to U.
template
struct Dequantize {
  U operator()(uint8_t x) {
    if constexpr (bits == 8) {
      return U(*(thread fp8_e4m3*)(&x));
    } else {
      return U(*(thread fp4_e2m1*)(&x));
    }
  }
};

// Expand one packed byte into 1 (8-bit) or 2 (4-bit, low nibble first)
// scaled values.
template
inline void dequantize(uint8_t w, U scale, threadgroup U* w_local) {
  if constexpr (bits == 4) {
    w_local[0] = scale * Dequantize<4, U>{}(w);
    w_local[1] = scale * Dequantize<4, U>{}(w >> 4);
  } else {
    w_local[0] = scale * Dequantize<8, U>{}(w);
  }
}

// Loads a BROWS x BCOLS tile of quantized weights, dequantizes it with its
// group scales and writes it to threadgroup memory with row stride dst_ld.
//
// Source rows are src_ld elements long and quantization groups run along
// the columns, so each source row has src_ld / group_size scales.
//   reduction_dim == 1: columns run along K (callers pass src_ld == K).
//   reduction_dim == 0: rows run along K (callers pass src_ld == N).
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short group_size,
    short bits>
struct QuantizedBlockLoader {
  MLX_MTL_CONST short pack_factor = get_pack_factor<8, bits>();
  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack();
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
  MLX_MTL_CONST short n_reads =
      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
  MLX_MTL_CONST short n_reads_per_scale =
      (n_reads * pack_factor) <= group_size ? n_reads : (group_size / pack_factor);
  MLX_MTL_CONST short n_steps_per_read = n_reads / n_reads_per_scale;
  MLX_MTL_CONST short n_groups = BCOLS / group_size;

  const int src_ld;
  const int tile_stride;
  const int group_stride;

  const short thread_idx;
  const short bi;
  const short bj;
  const short group_id;

  threadgroup T* dst;
  const device uint8_t* src;
  const device uint8_t* scales;

  QuantizedBlockLoader(
      const device uint8_t* src_,
      const device uint8_t* scales_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(
            reduction_dim ? BCOLS_PACKED * bytes_per_pack
                          : BROWS * src_ld * bytes_per_pack / pack_factor),
        // Number of scales covering BROWS source rows.
        group_stride(BROWS * src_ld / group_size),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        group_id((bj * pack_factor) / group_size),
        dst(dst_ + bi * dst_ld + bj * pack_factor),
        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
            bj * bytes_per_pack),
        scales(scales_ + bi * src_ld / group_size + group_id) {}

  void load_unsafe() const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    int k = 0;
    for (int i = 0; i < n_steps_per_read; i++) {
      T scale = dequantize_scale(scales[i]);
      for (int j = 0; j < n_reads_per_scale; j++) {
        dequantize(
            src[k * bytes_per_pack], scale, dst + k * pack_factor);
        k++;
      }
    }
  }

  // src_tile_dim is (columns, rows) of the valid source region. Every caller
  // passes the row extent in .y: short2(BK, tgp_bn) / short2(k_remain, tgp_bn)
  // for reduction_dim == 1, short2(tgp_bn, BK) / short2(tgp_bn, k_remain) for
  // reduction_dim == 0. Rows outside the region are zero-filled.
  // NOTE(review): columns beyond src_tile_dim.x are not masked; callers must
  // zero the matching activations.
  void load_safe(short2 src_tile_dim) const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    if (bi >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    int k = 0;
    for (int i = 0; i < n_steps_per_read; i++) {
      T scale = dequantize_scale(scales[i]);
      for (int j = 0; j < n_reads_per_scale; j++) {
        dequantize(
            src[k * bytes_per_pack], scale, dst + k * pack_factor);
        k++;
      }
    }
  }

  // Advance to the next tile along K.
  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      // Moved BCOLS columns along each row: skip BCOLS / group_size scales.
      scales += n_groups;
    } else {
      // Moved BROWS rows down: skip all BROWS * src_ld / group_size of their
      // scales.
      scales += group_stride;
    }
  }
};

using namespace mlx::steel;

// Computes one BM x BN tile of y = x * w^T. Here w is N x K, packed along K
// (K_w bytes per row), with K / group_size scales per row.
// Each of the WM x WN simdgroups computes an SM x SN sub-tile with NAX tile
// matmuls over SK-wide slices of K. Partial M / N tiles use the safe
// load/store variants selected through dispatch_bool.
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 64,
    const int BK = 64,
    const int BN = 64,
    const int WM = 2,
    const int WN = 2,
    typename Wtype = bfloat>
METAL_FUNC void fp_qmm_t_impl(
    const device uint32_t* w,
    const device uint8_t* scales,
    const device T* x,
    device T* y,
    threadgroup Wtype* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int pack_factor = get_pack_factor<8, bits>();
  constexpr int bytes_per_pack = get_bytes_per_pack();

  constexpr int BK_padded = (BK + 16 / sizeof(Wtype));

  // Instantiate Loader
  using loader_w_t = QuantizedBlockLoader<
      Wtype,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  auto wl = (const device uint8_t*)w;

  x += y_row * static_cast(K);
  // Widen before multiplying: y_col * K_w can exceed INT_MAX for large w.
  wl += size_t(y_col) * K_w;
  scales += size_t(y_col) * K_g;
  y += y_row * static_cast(N) + y_col;

  // Make the weight loader
  loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid);

  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;

  constexpr short TM = SM / 16;
  constexpr short TN = SN / 16;
  constexpr short TK = SK / 16;

  const short tm = SM * (simd_gid / WN);
  const short tn = SN * (simd_gid % WN);

  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;

  // Clamp in int before narrowing. short(M - ...) wraps negative for
  // remainders above SHRT_MAX, and simdgroups past the edge would otherwise
  // get a negative extent.
  const short sgp_sm = short(min(int(SM), max(0, M - (y_row + tm))));
  const bool is_unaligned_sm = (sgp_sm != SM);

  const short sgp_sn =
      aligned_N ? SN : short(min(int(SN), max(0, N - (y_col + tn))));

  const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col)));
  const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN);

  using AccumType = float;

  NAXTile Dtile;
  Dtile.clear();

  x += tm * K;

  dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) {
    dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if constexpr (kAlignedN.value) {
          loader_w.load_unsafe();
        } else {
          loader_w.load_safe(short2(BK, tgp_bn));
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        STEEL_PRAGMA_NO_UNROLL
        for (int kk1 = 0; kk1 < BK; kk1 += SK) {
          NAXTile Atile;
          NAXTile Btile;

          volatile int compiler_barrier;

          if constexpr (kAlignedM.value) {
            Atile.load(x + kk1, K);
          } else {
            Atile.load_safe(x + kk1, K, short2(SK, sgp_sm));
          }

          Btile.template load(Ws + tn * BK_padded + kk1);

          tile_matmad_nax(
              Dtile,
              Atile,
              metal::bool_constant{},
              Btile,
              metal::bool_constant{});

          (void)compiler_barrier;
        }

        x += BK;
        loader_w.next();
      }

      // Store results to device memory
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if constexpr (kAlignedM.value && kAlignedN.value) {
        Dtile.store(y + tm * N + tn, N);
      } else if (kAlignedM.value && sgp_sn == SN) {
        Dtile.store(y + tm * N + tn, N);
      } else {
        Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm));
      }
    });
  });
}

// Computes one BM x BN tile of y = x * w. Here w is K x N, packed along N,
// with N / group_size scales per row.
// NOTE(review): this path has no bounds handling (load_unsafe, Atile.load,
// Dtile.store), so it assumes M, N and K are multiples of BM, BN and BK.
// fp_quantized_nax.metal does not currently instantiate a qmm_n kernel.
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 64,
    const int BK = 64,
    const int BN = 64,
    const int WM = 2,
    const int WN = 2,
    typename Wtype = bfloat>
METAL_FUNC void fp_qmm_n_impl(
    const device uint32_t* w,
    const device uint8_t* scales,
    const device T* x,
    device T* y,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;
  (void)M;

  constexpr int pack_factor = get_pack_factor<8, bits>();
  constexpr int bytes_per_pack = get_bytes_per_pack();

  constexpr int BN_padded = (BN + 16 / sizeof(T));

  using loader_w_t = QuantizedBlockLoader<
      T,
      BK,
      BN,
      BN_padded,
      0,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  auto wl = (const device uint8_t*)w;

  x += y_row * static_cast(K);
  // w is K x N, so this column block starts y_col elements into every row
  // (the same offsets as the non-transposed fp_gather_qmm_rhs_nax path).
  wl += y_col * bytes_per_pack / pack_factor;
  scales += y_col / group_size;
  y += y_row * static_cast(N) + y_col;

  // Rows of w (and of its scales) are N elements long.
  loader_w_t loader_w(wl, scales, N, Ws, simd_gid, simd_lid);

  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;

  constexpr short TM = SM / 16;
  constexpr short TN = SN / 16;
  constexpr short TK = SK / 16;

  const short tm = SM * (simd_gid / WN);
  const short tn = SN * (simd_gid % WN);

  const short ldb_tgp = BN_padded;

  constexpr bool transpose_a = false;
  constexpr bool transpose_b = false;

  using AccumType = float;

  NAXTile Dtile;
  Dtile.clear();

  x += tm * K;

  for (int k = 0; k < K; k += BK) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    loader_w.load_unsafe();
    threadgroup_barrier(mem_flags::mem_threadgroup);

    STEEL_PRAGMA_NO_UNROLL
    for (int kk1 = 0; kk1 < BK; kk1 += SK) {
      NAXTile Atile;
      NAXTile Btile;

      volatile int compiler_barrier;

      Atile.load(x + kk1, K);
      Btile.template load(Ws + tn + kk1 * ldb_tgp);

      tile_matmad_nax(
          Dtile,
          Atile,
          metal::bool_constant{},
          Btile,
          metal::bool_constant{});

      (void)compiler_barrier;
    }

    x += BK;
    loader_w.next();
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);

  Dtile.store(y + tm * N + tn, N);
}

// Batched matmul: offsets x, w, scales and y to batch element tid.z.
// y advances by output_stride per batch element.
template
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device S*& scales,
    device T*& y,
    int output_stride,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx = tid.z;
  uint32_t w_idx = tid.z;
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
  }
  y += tid.z * output_stride;
}

// Gathered matmul: the x and w batch elements are read from lhs_indices and
// rhs_indices at position tid.z, then applied as in the non-gathered variant.
template
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device S*& scales,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T*& y,
    int output_stride,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant int64_t* lhs_strides,
    const constant int64_t* rhs_strides,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx;
  uint32_t w_idx;
  if (batch_ndims == 1) {
    x_idx = lhs_indices[tid.z * lhs_strides[0]];
    w_idx = rhs_indices[tid.z * rhs_strides[0]];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
    x_idx = lhs_indices[idx.x];
    w_idx = rhs_indices[idx.y];
  }
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
  }
  y += tid.z * output_stride;
}

// Kernel entry: (optionally batched) y = x * w^T.
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 64,
    const int BK = 64,
    const int BN = 64,
    const int WM = 2,
    const int WN = 2,
    typename Wtype = bfloat>
[[kernel]] void fp_qmm_t_nax(
    const device uint32_t* w,
    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(Wtype));

  threadgroup Wtype Ws[BN * BK_padded];

  if (batched) {
    adjust_matrix_offsets(
        x,
        w,
        scales,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        tid);
  }
  fp_qmm_t_impl(
      w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

// Kernel entry: (optionally batched) y = x * w. fp_qmm_n_impl only stages
// weights in threadgroup memory, so only Ws is allocated.
template <
    typename T,
    const int group_size,
    const int bits,
    const bool batched,
    const int BM = 64,
    const int BK = 64,
    const int BN = 64,
    const int WM = 2,
    const int WN = 2,
    typename Wtype = bfloat>
[[kernel]] void fp_qmm_n_nax(
    const device uint32_t* w,
    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Ws[BK * BN_padded];

  if (batched) {
    adjust_matrix_offsets(
        x,
        w,
        scales,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        tid);
  }

  // fp_qmm_n_impl takes a single threadgroup buffer (Ws).
  fp_qmm_n_impl(
      w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

// Kernel entry: gathered y = x[lhs_indices] * w[rhs_indices]^T.
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 64,
    const int BK = 64,
    const int BN = 64,
    const int WM = 2,
    const int WN = 2,
    typename Wtype = bfloat>
[[kernel]] void fp_gather_qmm_t_nax(
    const device uint32_t* w,
    const device uint8_t* scales,
    const device T* x,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T* y,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant int64_t* lhs_strides,
    const constant int64_t* rhs_strides,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(Wtype));

  threadgroup Wtype Ws[BN * BK_padded];

  adjust_matrix_offsets(
      x,
      w,
      scales,
      lhs_indices,
      rhs_indices,
      y,
      M * N,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      tid);
  fp_qmm_t_impl(
      w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

// Kernel entry: gathered y = x[lhs_indices] * w[rhs_indices].
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 64,
    const int BK = 64,
    const int BN = 64,
    const int WM = 2,
    const int WN = 2,
    typename Wtype = bfloat>
[[kernel]] void fp_gather_qmm_n_nax(
    const device uint32_t* w,
    const device uint8_t* scales,
    const device T* x,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T* y,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant int64_t* lhs_strides,
    const constant int64_t* rhs_strides,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Ws[BK * BN_padded];

  adjust_matrix_offsets(
      x,
      w,
      scales,
      lhs_indices,
      rhs_indices,
      y,
      M * N,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      tid);

  // fp_qmm_n_impl takes a single threadgroup buffer (Ws).
  fp_qmm_n_impl(
      w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

// Kernel entry: y = x * w[indices[row]] (w^T when `transpose`), where the
// expert index is read per row of x. Consecutive rows sharing an index are
// processed together, and only their output rows are stored.
template <
    typename T,
    int group_size,
    const int bits,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose,
    typename Wtype = bfloat>
[[kernel]] void fp_gather_qmm_rhs_nax(
    const device T* x,
    const device uint32_t* w,
    const device uint8_t* scales,
    const device uint32_t* indices,
    device T* y,
    const constant int& M,
    const constant int& N,
    const constant int& K,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  constexpr int pack_factor = get_pack_factor<8, bits>();
  constexpr int bytes_per_pack = get_bytes_per_pack();
  constexpr int BK_padded = (BK + 16 / sizeof(Wtype));
  constexpr int BN_padded = (BN + 16 / sizeof(Wtype));

  using loader_w_t = QuantizedBlockLoader<
      Wtype,
      transpose ? BN : BK,
      transpose ? BK : BN,
      transpose ? BK_padded : BN_padded,
      transpose,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  threadgroup Wtype Ws[transpose ? BN * BK_padded : BK * BN_padded];

  // Compute the block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int N_w = N * bytes_per_pack / pack_factor;
  const int N_g = N / group_size;
  const int K_it = K / BK;
  // Per-expert strides; widen before multiplying to avoid int overflow.
  const size_t stride_w = transpose ? size_t(N) * K_w : size_t(K) * N_w;
  const size_t stride_s = transpose ? size_t(N) * K_g : size_t(K) * N_g;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  const size_t y_row_long = size_t(y_row);
  const size_t y_col_long = size_t(y_col);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));

  // Calculate the final tiles in the case that K is not aligned
  const int k_remain = K - K_it * BK;
  const short2 tile_w =
      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

  // Move x and output to the correct block
  auto wl = (const device uint8_t*)w;
  x += y_row_long * K;
  y += y_row_long * N + y_col_long;
  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
  scales += transpose ? y_col_long * K_g : y_col / group_size;

  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;

  constexpr short TM = SM / 16;
  constexpr short TN = SN / 16;
  constexpr short TK = SK / 16;

  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  // Clamp in int before narrowing to short (remainders can exceed SHRT_MAX).
  const short sgp_sm =
      align_M ? SM : short(min(int(SM), max(0, M - (y_row + tm))));
  const short sgp_sn =
      align_N ? SN : short(min(int(SN), max(0, N - (y_col + tn))));

  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
  const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN);

  constexpr short BR = transpose ? TN : TK;
  constexpr short BC = transpose ? TK : TN;

  using AccumType = float;

  // Do as many matmuls as necessary
  uint32_t index;
  short offset;
  uint32_t index_next = indices[y_row];
  short offset_next = 0;
  int n = 0;
  while (n < tgp_bm) {
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = tgp_bm;
    for (; n < tgp_bm; n++) {
      if (indices[y_row + n] != index) {
        offset_next = n;
        index_next = indices[y_row + n];
        break;
      }
    }
    threadgroup_barrier(mem_flags::mem_none);

    // Prepare threadgroup mma operation
    NAXTile Dtile;
    Dtile.clear();

    const device T* xn = x + tm * K;

    // Prepare threadgroup loading operations
    thread loader_w_t loader_w(
        wl + index * stride_w,
        scales + index * stride_s,
        transpose ? K : N,
        Ws,
        simd_group_id,
        simd_lane_id);

    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
      dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) {
        for (int k = 0; k < K_it; k++) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          if constexpr (kAlignedN.value) {
            loader_w.load_unsafe();
          } else {
            loader_w.load_safe(
                transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
          }

          threadgroup_barrier(mem_flags::mem_threadgroup);

          STEEL_PRAGMA_NO_UNROLL
          for (int kk1 = 0; kk1 < BK; kk1 += SK) {
            NAXTile Atile;
            NAXTile Btile;

            volatile int compiler_barrier;

            if constexpr (kAlignedM.value) {
              Atile.load(xn + kk1, K);
            } else {
              Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm));
            }

            if constexpr (transpose) {
              Btile.template load(
                  Ws + tn * BK_padded + kk1);
            } else {
              Btile.template load(
                  Ws + tn + kk1 * BN_padded);
            }

            tile_matmad_nax(
                Dtile,
                Atile,
                metal::bool_constant{},
                Btile,
                metal::bool_constant{});

            (void)compiler_barrier;
          }

          xn += BK;
          loader_w.next();
        }

        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          loader_w.load_safe(tile_w);
          threadgroup_barrier(mem_flags::mem_threadgroup);

          STEEL_PRAGMA_NO_UNROLL
          for (int kk1 = 0; kk1 < BK; kk1 += SK) {
            NAXTile Atile;
            NAXTile Btile;

            volatile int compiler_barrier;

            // Only the first k_remain columns of this last K tile are valid;
            // zero-fill the rest of x so it does not contribute.
            const short psk = min(int(SK), max(0, (k_remain - kk1)));
            Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm));

            if constexpr (transpose) {
              Btile.template load(
                  Ws + tn * BK_padded + kk1);
            } else {
              Btile.template load(
                  Ws + tn + kk1 * BN_padded);
            }

            tile_matmad_nax(
                Dtile,
                Atile,
                metal::bool_constant{},
                Btile,
                metal::bool_constant{});

            (void)compiler_barrier;
          }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm));
        const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm));

        // Store results to device memory
        if constexpr (kAlignedN.value) {
          if (m_lo_lim == 0 && m_hi_lim == SM) {
            Dtile.store(y + tm * N + tn, N);
          } else {
            Dtile.store_slice(
                y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim));
          }
        } else {
          Dtile.store_slice(
              y + tm * N + tn,
              N,
              short2(0, m_lo_lim),
              short2(sgp_sn, m_hi_lim));
        }
      });
    });
  }
}

================================================
FILE: mlx/backend/metal/kernels/fp_quantized_nax.metal
================================================
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/fp_quantized_nax.h"

// Instantiation helpers for the NAX fp-quantized kernels in
// fp_quantized_nax.h. Kernel names are built by stringizing the arguments:
//   "<mode>_<name>_<type>_gs_<group_size>_b_<bits>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>"
// followed by variant suffixes.
// NOTE(review): in the three helpers below, bm/bn/bk/wm/wn appear only in the
// name and are not passed as template args, so the kernels use their default
// tile sizes (BM = BK = BN = 64, WM = WN = 2). Keep the names in sync with
// those defaults.

// Adds `batched`; suffix "_batch_<batched>". (Not used below.)
#define instantiate_quantized_batched(mode, name, type, bm, bn, bk, wm, wn, batched, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
    fp_ ## name, \
    type, \
    group_size, \
    bits, \
    batched)

// Adds `aligned` (N-aligned); suffix "_alN_<aligned>".
#define instantiate_quantized_aligned(mode, name, type, bm, bn, bk, wm, wn, aligned, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
    fp_ ## name, \
    type, \
    group_size, \
    bits, \
    aligned)

// Both `aligned` and `batched`; suffix "_alN_<aligned>_batch_<batched>".
#define instantiate_quantized_aligned_batched(mode, name, type, bm, bn, bk, wm, wn, aligned, batched, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
    fp_ ## name, \
    type, \
    group_size, \
    bits, \
    aligned, \
    batched)

// Explicit kernel `func`. Here the tile sizes and `transpose` are real
// template args; `transpose` is encoded in `name` rather than stringized.
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose, mode, group_size, bits) \
  instantiate_kernel( \
    #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
    func, \
    type, \
    group_size, \
    bits, \
    bm, \
    bn, \
    bk, \
    wm, \
    wn, \
    transpose)

// gather_qmm_t_nax and qmm_t_nax for aligned/unaligned N (qmm_t_nax also
// batched/non-batched).
#define instantiate_quantized_all_aligned(type, mode, group_size, bits) \
  instantiate_quantized_aligned(mode, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, true, group_size, bits) \
  instantiate_quantized_aligned(mode, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, false, group_size, bits) \
  instantiate_quantized_aligned_batched(mode, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 1, group_size, bits) \
  instantiate_quantized_aligned_batched(mode, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 0, group_size, bits) \
  instantiate_quantized_aligned_batched(mode, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 1, group_size, bits) \
  instantiate_quantized_aligned_batched(mode, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 0, group_size, bits)

// fp_gather_qmm_rhs_nax with 64x64x64 tiles and 2x2 simdgroups; "_nt" uses
// transpose = true, "_nn" uses transpose = false.
#define instantiate_quantized_all_rhs(type, mode, group_size, bits) \
  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, gather_qmm_rhs_nax_nt, type, 64, 64, 64, 2, 2, true, mode, group_size, bits) \
  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, gather_qmm_rhs_nax_nn, type, 64, 64, 64, 2, 2, false, mode, group_size, bits)

// Every NAX kernel for one (type, mode, group_size, bits) combination.
#define instantiate_quantized_modes(type, mode, group_size, bits) \
  instantiate_quantized_all_aligned(type, mode, group_size, bits) \
  instantiate_quantized_all_rhs(type, mode, group_size, bits)

// Supported modes: nvfp4 (gs 16, 4 bits), mxfp8 (gs 32, 8 bits),
// mxfp4 (gs 32, 4 bits).
#define instantiate_quantized_types(type) \
  instantiate_quantized_modes(type, nvfp4, 16, 4) \
  instantiate_quantized_modes(type, mxfp8, 32, 8) \
  instantiate_quantized_modes(type, mxfp4, 32, 4)

instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on

================================================
FILE: mlx/backend/metal/kernels/gemv.metal
================================================
// Copyright © 2023-2024 Apple Inc.
#include <metal_simdgroup>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

#define MLX_MTL_CONST static constant constexpr const

// Accumulation type used by the gemv kernels: float for real inputs,
// complex64_t for complex inputs.
template <typename T>
struct DefaultAccT {
  using type = float;
};
template <>
struct DefaultAccT<complex64_t> {
  using type = complex64_t;
};

// Matrix-vector product: out[r] = sum_c mat[r * matrix_ld + c] * in_vec[c],
// optionally followed by out = alpha * out + beta * bias (kDoAxpby).
template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoAxpby, /* out = alpha * (mat @ vec) + beta * bias */
    typename AccT = typename DefaultAccT<T>::type>
struct GEMVKernel {
  using acc_type = AccT;

  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

  MLX_MTL_CONST int blockM = threadsM * TM;
  MLX_MTL_CONST int blockN = threadsN * TN;

  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  static_assert(
      SN == 4 || SN == 8 || SN == 16 || SN == 32,
      "gemv block must have a width of 4, 8, 16, or 32");

  // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
  //   into blocks of (blockM, blockN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM rows
  //    and the corresponding scalar from the vector
  // 2. The thread then multiplies and adds to accumulate its local result for
  //    the block
  // 3. At the end, each thread has accumulated results over all blocks across
  //    the rows. These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated blockM outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid has blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results
  //     remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted
  //     inwards such that the thread block fits exactly in the matrix

  // One (blockM + TM)-sized row of partial sums per simdgroup column.
  MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
  MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;

  // Loads TN consecutive elements src[src_offset .. src_offset + TN) without
  // bounds checks, converting each to U.
  template <typename U = T>
  static METAL_FUNC void
  load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
    MLX_MTL_PRAGMA_UNROLL
    for (int tn = 0; tn < TN; tn++) {
      dst[tn] = static_cast<U>(src[src_offset + tn]);
    }
  }

  // Bounds-checked variant of load_unsafe: elements at or past src_size are
  // filled with U(0).
  template <typename U = T>
  static METAL_FUNC void load_safe(
      const device T* src,
      thread U dst[TN],
      const int src_offset = 0,
      const int src_size = TN) {
    if (src_offset + TN <= src_size) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = static_cast<U>(src[src_offset + tn]);
      }
    } else { // Edgecase
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = src_offset + tn < src_size
            ? static_cast<U>(src[src_offset + tn])
            : U(0);
      }
    }
  }

  // Computes blockM outputs for threadgroup tid.x. `mat`, `in_vec`, `bias`
  // and `out_vec` must already point at the current batch element.
  // tgp_memory may be nullptr when tgp_mem_size == 0.
  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      const device T* bias [[buffer(2)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const constant float& alpha [[buffer(7)]],
      const constant float& beta [[buffer(8)]],
      const constant int& bias_stride [[buffer(14)]],
      threadgroup AccT* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Appease compiler
    (void)lid;

    // Thread local accumulation results
    thread AccT result[TM] = {0};
    thread T inter[TN];
    thread AccT v_coeff[TN];

    // Thread position inside its simdgroup (rows x cols = SM x SN)
    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
    const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;

    int bm = (simdM + thrM) * TM;
    int bn = (simdN + thrN) * TN;

    // Block position
    int out_row = tid.x * blockM + bm;

    // Exit simdgroup if rows out of bound
    if (out_row >= out_vec_size)
      return;

    // Adjust tail simdgroup to ensure in bound reads
    out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;

    // Advance matrix. The row offset is widened to 64 bits because
    // out_row * matrix_ld can exceed INT_MAX for large matrices.
    mat += int64_t(out_row) * matrix_ld;

    constexpr const uniform<int> loop_stride = make_uniform(blockN);
    const uniform<int> in_size = make_uniform(in_vec_size);
    const uniform<int> n_iter = in_size / loop_stride;
    const uniform<int> last_iter = loop_stride * n_iter;
    const uniform<int> leftover = in_size - last_iter;

    // Loop over in_vec in blocks of blockN
    for (int i = 0; i < n_iter; ++i) {
      load_unsafe<AccT>(in_vec, v_coeff, bn);

      // Per thread work loop
      int mat_offset = 0;
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        // Load for the row
        load_unsafe<T>(mat, inter, mat_offset + bn);

        // Accumulate results
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          result[tm] += inter[tn] * v_coeff[tn];
        }

        mat_offset += matrix_ld;
      }

      bn += blockN;
    }

    // Tail of in_vec that does not fill a whole blockN
    if (leftover > 0) {
      load_safe<AccT>(in_vec, v_coeff, bn, in_size);

      // Per thread work loop
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        // Load for the row
        load_safe<T>(&mat[tm * matrix_ld], inter, bn, in_size);

        // Accumulate results
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          result[tm] += inter[tn] * v_coeff[tn];
        }
      }
    }

    // Simdgroup accumulations: reduce across the SN lanes of each row so that
    // the thrN == 0 lane holds the row sum.
    MLX_MTL_PRAGMA_UNROLL
    for (int tm = 0; tm < TM; tm++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
        result[tm] += simd_shuffle_down(result[tm], sn);
      }
    }

    // Threadgroup accumulation results
    if (needs_tgp_reduction) {
      threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
      if (thrN == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          tgp_results[tm] = result[tm];
        }

        // mem_threadgroup (not mem_none) is required so that the writes above
        // are visible to the sgN == 0 simdgroup reading them below.
        // NOTE(review): this barrier sits inside `if (thrN == 0)` and after
        // the early `return`, so not every thread reaches it; the MSL spec
        // requires all threads to execute a barrier placed in a conditional.
        // Presumably this relies on Apple GPU behaviour; confirm before
        // restructuring.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (sgN == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int sgn = 1; sgn < BN; sgn++) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tm = 0; tm < TM; tm++) {
              result[tm] += tgp_results[sgn * (blockM + TM) + tm];
            }
          }
        }
      }
    }

    // Write outputs
    if (simdN == 0 && thrN == 0) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        if (kDoAxpby) {
          out_vec[out_row + tm] =
              static_cast<T>(alpha) * static_cast<T>(result[tm]) +
              static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
        } else {
          out_vec[out_row + tm] = static_cast<T>(result[tm]);
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

// Vector-matrix product: out[c] = sum_r in_vec[r] * mat[r * matrix_ld + c],
// optionally followed by out = alpha * out + beta * bias (kDoAxpby).
template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoAxpby, /* out = alpha * (vec @ mat) + beta * bias */
    typename AccT = typename DefaultAccT<T>::type>
struct GEMVTKernel {
  using acc_type = AccT;

  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

  MLX_MTL_CONST int blockM = threadsM * TM;
  MLX_MTL_CONST int blockN = threadsN * TN;

  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
  //   into blocks of (blockM, blockN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM contiguous rows
  //    and the corresponding scalar from the vector
  // 2. The thread then accumulates its local result for the block
  // 3. At the end, each thread has accumulated results over all blocks across
  //    the rows. These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated blockN outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid has blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results
  //     remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted
  //     inwards such that the thread block fits exactly in the matrix

  // One (blockN + TN)-sized row of partial sums per simdgroup row.
  MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
  MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;

  // Computes the outputs for threadgroup tid.x. `mat`, `in_vec`, `bias` and
  // `out_vec` must already point at the current batch element.
  // tgp_memory may be nullptr when tgp_mem_size == 0.
  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      const device T* bias [[buffer(2)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const constant float& alpha [[buffer(7)]],
      const constant float& beta [[buffer(8)]],
      const constant int& bias_stride [[buffer(14)]],
      threadgroup AccT* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Appease compiler
    (void)lid;

    // Thread local accumulation results
    AccT result[TN] = {0};
    T inter[TN];
    AccT v_coeff[TM];
    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = SM * sgM;
    const int simdN = SN * sgN;

    int cm = (simdM + thrM);
    int cn = (simdN + thrN);

    int bm = cm * TM;
    int bn = cn * TN;

    int out_col = tid.x * blockN + bn;

    constexpr const uniform<int> loop_stride = make_uniform(blockM);
    const uniform<int> in_size = make_uniform(in_vec_size);
    const uniform<int> n_iter = in_size / loop_stride;
    const uniform<int> last_iter = loop_stride * n_iter;
    const uniform<int> leftover = in_size - last_iter;

    // Edgecase handling
    if (out_col < out_vec_size) {
      out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;

      // NOTE(review): (bm + tm) * matrix_ld is computed in 32-bit int and
      // would overflow if in_vec_size * matrix_ld exceeds INT_MAX; this
      // presumably relies on the host limiting matrix sizes. Confirm.

      // Per thread accumulation main loop
      for (int i = 0; i < n_iter; ++i) {
        // Adding a threadgroup_barrier improves performance slightly,
        // possibly because it helps exploit the cache better.
        threadgroup_barrier(mem_flags::mem_none);

        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
        }

        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          auto vc = static_cast<AccT>(v_coeff[tm]);
          for (int tn = 0; tn < TN; tn++) {
            inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
          }
          for (int tn = 0; tn < TN; tn++) {
            result[tn] += vc * inter[tn];
          }
        }

        bm += blockM;
      }

      // Tail of in_vec that does not fill a whole blockM
      if (leftover > 0) {
        for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
          v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);

          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            result[tn] += v_coeff[tm] * inter[tn];
          }
        }
      }
    }

    // Simdgroup accumulations: reduce across the SM rows of each simdgroup so
    // that the thrM == 0 lanes hold the column sums.
    MLX_MTL_PRAGMA_UNROLL
    for (int tn = 0; tn < TN; tn++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
        result[tn] += simd_shuffle_down(result[tn], SN * sm);
      }
    }

    // Threadgroup accumulation results
    if (needs_tgp_reduction) {
      threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
      if (thrM == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          tgp_results[tn] = result[tn];
        }

        // mem_threadgroup (not mem_none) is required so that the writes above
        // are visible to the sgM == 0 simdgroup reading them below.
        // NOTE(review): like GEMVKernel, this barrier is inside divergent
        // control flow (`thrM == 0`); confirm before restructuring.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (sgM == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int sgm = 1; sgm < BM; sgm++) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tn = 0; tn < TN; tn++) {
              result[tn] += tgp_results[sgm * (blockN + TN) + tn];
            }
          }
        }
      }
    }

    // Threadgroup accumulation and writing out results
    if (cm == 0 && out_col < out_vec_size) {
      MLX_MTL_PRAGMA_UNROLL
      for (int j = 0; j < TN; j++) {
        if (kDoAxpby) {
          out_vec[out_col + j] =
              static_cast<T>(alpha) * static_cast<T>(result[j]) +
              static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
        } else {
          out_vec[out_col + j] = static_cast<T>(result[j]);
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

// Batched mat-vec kernel. tid.z selects the batch element. With kDoNCBatch,
// offsets come from elem_to_loc over batch_shape; otherwise the batch is
// treated as 1-D using stride[0].
template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoNCBatch, /* Batch ndim > 1 */
    const bool kDoAxpby> /* out = alpha * (mat @ vec) + beta * bias */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* vector_batch_stride [[buffer(11)]],
    const constant int64_t* matrix_batch_stride [[buffer(12)]],
    const constant int64_t* bias_batch_stride [[buffer(13)]],
    const constant int& bias_stride [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (kDoAxpby) {
      bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (kDoAxpby) {
      bias += tid.z * bias_batch_stride[0];
    }
  }

  // Outputs are stored contiguously, out_vec_size per batch element
  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      alpha,
      beta,
      bias_stride,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

#define instantiate_gemv_helper(                                       \
    name, itype, bm, bn, sm, sn, tm, tn, nc, axpby)                    \
  instantiate_kernel(                                                  \
      "gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm  \
      "_tn" #tn "_nc" #nc "_axpby" #axpby,                             \
      gemv,                                                            \
      itype,                                                           \
      bm,                                                              \
      bn,                                                              \
      sm,                                                              \
      sn,                                                              \
      tm,                                                              \
      tn,                                                              \
      nc,                                                              \
      axpby)

// clang-format off
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn)        \
  instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
  instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
  instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
  instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on

// clang-format off
#define instantiate_gemv_blocks(name, itype)           \
  instantiate_gemv(name, itype, 1, 8,  1, 32, 4, 4) \
  instantiate_gemv(name, itype, 1, 8,  1, 32, 1, 4) \
  instantiate_gemv(name, itype, 1, 1,  8, 4,  4, 4) \
  instantiate_gemv(name, itype, 1, 1,  8, 4,  1, 4) \
  instantiate_gemv(name, itype, 4, 1,  1, 32, 1, 4) \
  instantiate_gemv(name, itype, 4, 1,  1, 32, 4, 4) \
  instantiate_gemv(name, itype, 8, 1,  1, 32, 4, 4) // clang-format on

instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
instantiate_gemv_blocks(complex64, complex64_t);

// Gathered mat-vec kernel. For batch element tid.z, the vector and matrix
// batch indices are read from vec_indices / mat_indices. index_batch_strides
// holds batch_ndim strides for vec_indices followed by batch_ndim strides for
// mat_indices. No bias is applied (kDoAxpby == false).
template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_gather(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* index_batch_strides [[buffer(11)]],
    const constant int& vector_batch_ndim [[buffer(12)]],
    const constant int* vector_batch_shape [[buffer(13)]],
    const constant int64_t* vector_batch_stride [[buffer(14)]],
    const constant int& matrix_batch_ndim [[buffer(15)]],
    const constant int* matrix_batch_shape [[buffer(16)]],
    const constant int64_t* matrix_batch_stride [[buffer(17)]],
    const constant uint32_t* vec_indices [[buffer(18)]],
    const constant uint32_t* mat_indices [[buffer(19)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  uint32_t indx_vec;
  uint32_t indx_mat;

  // Update batch offsets
  if (batch_ndim > 1) {
    const constant auto* veci_bstrides = index_batch_strides;
    const constant auto* mati_bstrides = index_batch_strides + batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);

    indx_vec = vec_indices[batch_offsets.x];
    indx_mat = mat_indices[batch_offsets.y];

  } else {
    indx_vec = vec_indices[index_batch_strides[0] * tid.z];
    indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
  }

  if (vector_batch_ndim > 1) {
    in_vec += elem_to_loc(
        indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
  } else {
    in_vec += indx_vec * vector_batch_stride[0];
  }

  if (matrix_batch_ndim > 1) {
    mat += elem_to_loc(
        indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
  } else {
    mat += indx_mat * matrix_batch_stride[0];
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      alpha,
      beta,
      batch_ndim, // bias_stride slot; unused since kDoAxpby == false
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

// clang-format off
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
  instantiate_kernel(                                                 \
      "gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm                \
      "_sn" #sn "_tm" #tm "_tn" #tn,                                  \
      gemv_gather, itype, bm, bn, sm, sn, tm, tn)

#define instantiate_gemv_bs_blocks(name, itype)              \
  instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
  instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
  instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on

instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_bs_blocks(complex64, complex64_t);

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

// Batched vec-mat kernel. Batch handling mirrors `gemv` above.
template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoNCBatch, /* Batch ndim > 1 */
    const bool kDoAxpby> /* out = alpha * (vec @ mat) + beta * bias */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* vector_batch_stride [[buffer(11)]],
    const constant int64_t* matrix_batch_stride [[buffer(12)]],
    const constant int64_t* bias_batch_stride [[buffer(13)]],
    const constant int& bias_stride [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (kDoAxpby) {
      bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (kDoAxpby) {
      bias += tid.z * bias_batch_stride[0];
    }
  }

  // Outputs are stored contiguously, out_vec_size per batch element
  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      alpha,
      beta,
      bias_stride,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

// clang-format off
#define instantiate_gemv_t_helper(                               \
    name, itype, bm, bn, sm, sn, tm, tn, nc, axpby)              \
  instantiate_kernel(                                            \
      "gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn    \
      "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby,             \
      gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)

#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn)        \
  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on

// clang-format off
#define instantiate_gemv_t_blocks(name, itype) \
  instantiate_gemv_t(name, itype, 1, 2,  8, 4, 4, 1) \
  instantiate_gemv_t(name, itype, 1, 2,  8, 4, 4, 4) \
  instantiate_gemv_t(name, itype, 1, 4,  8, 4, 4, 4) \
  instantiate_gemv_t(name, itype, 1, 16, 8, 4, 4, 4) \
  instantiate_gemv_t(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on

// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on

// Gathered vec-mat kernel. Index handling mirrors `gemv_gather` above; no
// bias is applied (kDoAxpby == false).
template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_gather(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* index_batch_strides [[buffer(11)]],
    const constant int& vector_batch_ndim [[buffer(12)]],
    const constant int* vector_batch_shape [[buffer(13)]],
    const constant int64_t* vector_batch_stride [[buffer(14)]],
    const constant int& matrix_batch_ndim [[buffer(15)]],
    const constant int* matrix_batch_shape [[buffer(16)]],
    const constant int64_t* matrix_batch_stride [[buffer(17)]],
    const constant uint32_t* vec_indices [[buffer(18)]],
    const constant uint32_t* mat_indices [[buffer(19)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  uint32_t indx_vec;
  uint32_t indx_mat;

  // Update batch offsets
  if (batch_ndim > 1) {
    const constant auto* veci_bstrides = index_batch_strides;
    const constant auto* mati_bstrides = index_batch_strides + batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);

    indx_vec = vec_indices[batch_offsets.x];
    indx_mat = mat_indices[batch_offsets.y];

  } else {
    indx_vec = vec_indices[index_batch_strides[0] * tid.z];
    indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
  }

  if (vector_batch_ndim > 1) {
    in_vec += elem_to_loc(
        indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
  } else {
    in_vec += indx_vec * vector_batch_stride[0];
  }

  if (matrix_batch_ndim > 1) {
    mat += elem_to_loc(
        indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
  } else {
    mat += indx_mat * matrix_batch_stride[0];
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      alpha,
      beta,
      batch_ndim, // bias_stride slot; unused since kDoAxpby == false
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

// clang-format off
#define instantiate_gemv_t_bs_helper(                   \
    nm, itype, bm, bn, sm, sn, tm, tn)                  \
  instantiate_kernel(                                   \
      "gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
      "_sn" #sn "_tm" #tm "_tn" #tn,                    \
      gemv_t_gather, itype, bm, bn, sm, sn, tm, tn)

#define instantiate_gemv_t_bs_blocks(name, itype)              \
  instantiate_gemv_t_bs_helper(name, itype, 1, 2,  8, 4, 4, 1) \
  instantiate_gemv_t_bs_helper(name, itype, 1, 2,  8, 4, 4, 4) \
  instantiate_gemv_t_bs_helper(name, itype, 1, 4,  8, 4, 4, 4) \
  instantiate_gemv_t_bs_helper(name, itype, 1, 16, 8, 4, 4, 4) \
  instantiate_gemv_t_bs_helper(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on

// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on

================================================
FILE: mlx/backend/metal/kernels/gemv_masked.h
================================================
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/steel/utils.h"

// NOTE(review): in this copy several template parameter lists and template
// argument lists appear to be missing (e.g. `template struct ScaleOp`,
// `static_cast(x)`, `metal::is_same_v;`, `using gemv_kernel = GEMVKernel;`).
// The code below is kept token-for-token; restore the angle-bracket contents
// from upstream before compiling.

using namespace metal;

// Shorthand for compile-time constants inside the kernel structs below.
#define MLX_MTL_CONST static constant constexpr const
// Requests full unrolling of the loop that follows.
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

// Placeholder mask type for "no mask": its bool conversion returns true in
// every address space, so a nomask_t value never disables a block.
struct _NoMask {
  char x;

  constexpr METAL_FUNC operator bool() {
    return true;
  }
  constexpr METAL_FUNC operator bool() const threadgroup {
    return true;
  }
  constexpr METAL_FUNC operator bool() const device {
    return true;
  }
  constexpr METAL_FUNC operator bool() const constant {
    return true;
  }
};

typedef struct _NoMask nomask_t;

// Converts its input to the output type and multiplies it by `scale`.
template struct ScaleOp {
  OutT scale;

  METAL_FUNC OutT apply(InT x) const {
    return static_cast(x) * scale;
  }
};

// Masked matrix-vector building block: out_vec = mat @ in_vec.
// mat has out_vec_size rows, in_vec_size columns and leading dimension
// matrix_ld. An optional output mask can zero whole output blocks. Optional
// operand masks (mat_mask, vec_mask) can skip, or multiplicatively scale,
// each blockN-wide step along in_vec.
template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    typename AccT = float>
struct GEMVKernel {
  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

  MLX_MTL_CONST int blockM = threadsM * TM;
  MLX_MTL_CONST int blockN = threadsN * TN;

  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  static_assert(
      SN == 8 || SN == 16 || SN == 32,
      "gemv block must have a width of 8, 16, or 32");

  static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");

  // NOTE(review): the is_same_v template arguments are missing in this copy.
  // Presumably these compare out_mask_t / op_mask_t against nomask_t (and
  // against bool for the multiplicative variants); confirm against upstream.
  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v;
  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v;

  MLX_MTL_CONST bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v;
  MLX_MTL_CONST bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v;

  // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
  //   into blocks of (blockM, blockN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM rows
  //    and the corresponding scalar from the vector
  // 2. The thread then multiplies and adds to accumulate its local result for
  //    the block
  // 3. At the end, each thread has accumulated results over all blocks across
  //    the rows. These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated blockM outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid has blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results
  //     remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted
  //     inwards such that the thread block fits exactly in the matrix

  // Threadgroup scratch: one (blockM + TM)-element slot of partial sums per
  // simdgroup column (indexed by sgN). Only used when BN > 1.
  MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
  MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;

  // Loads TN consecutive elements src[src_offset .. src_offset + TN) into
  // dst, converting them to dst's element type. No bounds checks.
  template
  static METAL_FUNC void
  load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
    MLX_MTL_PRAGMA_UNROLL
    for (int tn = 0; tn < TN; tn++) {
      dst[tn] = static_cast(src[src_offset + tn]);
    }
  }

  // Like load_unsafe, but positions at or beyond src_size are filled with
  // U(0) instead of being read.
  template
  static METAL_FUNC void load_safe(
      const device T* src,
      thread U dst[TN],
      const int src_offset = 0,
      const int src_size = TN) {
    if (src_offset + TN <= src_size) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = static_cast(src[src_offset + tn]);
      }
    } else { // Edgecase
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = src_offset + tn < src_size
            ? static_cast(src[src_offset + tn])
            : U(0);
      }
    }
  }

  // Computes this thread's partial dot products for TM output rows. After
  // the simdgroup (and, when BN > 1, threadgroup) reduction, the threads with
  // simdN == 0 && thrN == 0 write TM outputs starting at out_row.
  //
  // mask_strides is read as consecutive pairs:
  //   - output-mask strides first (only when an output mask is present);
  //   - then matrix-mask strides and vector-mask strides (only when an
  //     operand mask is present).
  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const device out_mask_t* out_mask [[buffer(20)]],
      const device op_mask_t* mat_mask [[buffer(21)]],
      const device op_mask_t* vec_mask [[buffer(22)]],
      const constant int* mask_strides [[buffer(23)]],
      threadgroup AccT* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Appease compiler
    (void)lid;

    // Thread local accumulation results
    thread AccT result[TM] = {0};
    thread T inter[TN];
    thread AccT v_coeff[TN];

    // Thread position inside its simdgroup (thrM, thrN), and simdgroup
    // position inside the threadgroup (simdM, simdN, in threads).
    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
    const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;

    // Element offsets of this thread's (TM, TN) tile within the block.
    int bm = (simdM + thrM) * TM;
    int bn = (simdN + thrN) * TN;

    // Block position
    int out_row = tid.x * blockM + bm;

    // Exit simdgroup if rows out of bound
    if (out_row >= out_vec_size)
      return;

    // Adjust tail simdgroup to ensure in bound reads
    out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;

    // Prepare mask offsets
    const constant int* out_mask_strides = mask_strides;
    const constant int* mat_mask_strides =
        mask_strides + (has_output_mask ? 2 : 0);
    const constant int* vec_mask_strides =
        mat_mask_strides + (has_operand_mask ? 2 : 0);

    // Mask-block index for out_row: out_row / blockN when blockN > blockM,
    // otherwise the threadgroup index. (Presumably mask blocks are blockN
    // wide; confirm against the host code.)
    const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);

    const int out_mask_offset =
        !has_output_mask ? 0 : m_block_idx * out_mask_strides[1];

    int mat_mask_offset =
        !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
    int vec_mask_offset = 0;
    // Mask offset increments applied after each blockN-wide step along in_vec.
    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
    const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];

    T out_scale{1};

    // Check output mask
    if (has_output_mask) {
      auto mask_out = out_mask[out_mask_offset];

      // Write zeros and return if mask is 0
      if (!mask_out) {
        if (simdN == 0 && thrN == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int tm = 0; tm < TM; tm++) {
            out_vec[out_row + tm] = T(0.);
          }
        }

        return;
      }

      // Store scalar if multiplicative mask
      if (has_mul_output_mask) {
        out_scale = T(mask_out);
      }
    }

    // Advance matrix
    mat += out_row * matrix_ld;

    // Prepare for loop
    constexpr const uniform loop_stride = make_uniform(blockN);
    const uniform in_size = make_uniform(in_vec_size);
    const uniform n_iter = in_size / loop_stride;
    const uniform last_iter = loop_stride * n_iter;
    const uniform leftover = in_size - last_iter;

    // Loop over in_vec in blocks of blockN
    for (int i = 0; i < n_iter; ++i) {
      if (!has_operand_mask ||
          (bool(mat_mask[mat_mask_offset]) &&
           bool(vec_mask[vec_mask_offset]))) {
        T block_scale{1};
        if (has_mul_operand_mask) {
          block_scale =
              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
        }

        load_unsafe(in_vec, v_coeff, bn);

        // Apply scale
        if (has_mul_operand_mask) {
          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            v_coeff[tn] *= block_scale;
          }
        }

        // Per thread work loop
        int mat_offset = 0;
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          // Load for the row
          load_unsafe(mat, inter, mat_offset + bn);

          // Accumulate results
          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            result[tm] += inter[tn] * v_coeff[tn];
          }

          mat_offset += matrix_ld;
        }
      }

      bn += blockN;
      mat_mask_offset += mat_mask_step;
      vec_mask_offset += vec_mask_step;
    }

    // Remaining (fewer than blockN) columns, using bounds-checked loads.
    if (leftover > 0) {
      if (!has_operand_mask ||
          (bool(mat_mask[mat_mask_offset]) &&
           bool(vec_mask[vec_mask_offset]))) {
        T block_scale{1};
        if (has_mul_operand_mask) {
          block_scale =
              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
        }

        load_safe(in_vec, v_coeff, bn, in_size);

        // Apply scale
        if (has_mul_operand_mask) {
          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            v_coeff[tn] *= block_scale;
          }
        }

        // Per thread work loop
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          // Load for the row
          load_safe(&mat[tm * matrix_ld], inter, bn, in_size);

          // Accumulate results
          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            result[tm] += inter[tn] * v_coeff[tn];
          }
        }
      }
    }

    // Apply out scale
    if (has_mul_output_mask) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        result[tm] *= out_scale;
      }
    }

    // Simdgroup accumulations
    MLX_MTL_PRAGMA_UNROLL
    for (int tm = 0; tm < TM; tm++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
        result[tm] += simd_shuffle_down(result[tm], sn);
      }
    }

    // Threadgroup accumulation results
    // NOTE(review): the barrier below has two problems:
    //   - it is executed only by threads with thrN == 0 (and after the early
    //     returns above);
    //   - it uses mem_flags::mem_none, although partial sums are exchanged
    //     through threadgroup memory.
    // Every GEMVKernel instantiation in gemv_masked.metal uses BN == 1, so
    // needs_tgp_reduction is false and this branch is currently dead.
    // Revisit both points (uniform control flow, mem_threadgroup) before
    // enabling BN > 1.
    if (needs_tgp_reduction) {
      threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
      if (thrN == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          tgp_results[tm] = result[tm];
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (sgN == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int sgn = 1; sgn < BN; sgn++) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tm = 0; tm < TM; tm++) {
              result[tm] += tgp_results[sgn * (blockM + TM) + tm];
            }
          }
        }
      }
    }

    // Write outputs
    if (simdN == 0 && thrN == 0) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        out_vec[out_row + tm] = static_cast(result[tm]);
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

// Masked vector-matrix building block: out_vec = in_vec @ mat.
// mat has in_vec_size rows, out_vec_size columns and leading dimension
// marix_ld. Masks behave as in GEMVKernel, but operand-mask steps advance
// along the in_vec (row) direction in blockM-high steps.
template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    typename AccT = float>
struct GEMVTKernel {
  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

  MLX_MTL_CONST int blockM = threadsM * TM;
  MLX_MTL_CONST int blockN = threadsN * TN;

  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  // NOTE(review): is_same_v template arguments are missing in this copy;
  // presumably the same comparisons as in GEMVKernel.
  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v;
  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v;

  MLX_MTL_CONST bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v;
  MLX_MTL_CONST bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v;

  // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
  //   into blocks of (blockM, blockN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM contiguous rows
  //    and the corresponding scalar from the vector
  // 2. The thread then accumulates its local result for the block
  // 3. At the end, each thread has accumulated results over all blocks across
  //    the rows. These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated BN * TN outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid has blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results
  //     remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted
  //     inwards such that the thread block fits exactly in the matrix

  // Threadgroup scratch: one (blockN + TN)-element slot of partial sums per
  // simdgroup row (indexed by sgM). Only used when BM > 1.
  MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
  MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;

  // Computes this thread's partial sums for TN output columns starting at
  // out_col. After the reductions, the threads with cm == 0 write them out.
  // mask_strides uses the same pairwise layout as GEMVKernel::run.
  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& marix_ld [[buffer(6)]],
      const device out_mask_t* out_mask [[buffer(20)]],
      const device op_mask_t* mat_mask [[buffer(21)]],
      const device op_mask_t* vec_mask [[buffer(22)]],
      const constant int* mask_strides [[buffer(23)]],
      threadgroup AccT* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Appease compiler
    (void)lid;

    // Thread local accumulation results
    AccT result[TN] = {0};
    T inter[TN];
    AccT v_coeff[TM];

    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = SM * sgM;
    const int simdN = SN * sgN;

    // Thread coordinates (cm, cn) and element offsets (bm, bn) in the block.
    int cm = (simdM + thrM);
    int cn = (simdN + thrN);

    int bm = cm * TM;
    int bn = cn * TN;

    int out_col = tid.x * blockN + bn;

    // Prepare mask offsets
    const constant int* out_mask_strides = mask_strides;
    const constant int* mat_mask_strides =
        out_mask_strides + (has_output_mask ? 2 : 0);
    const constant int* vec_mask_strides =
        mat_mask_strides + (has_operand_mask ? 2 : 0);

    const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);

    // NOTE(review): unlike GEMVKernel, the output-mask offset is not
    // multiplied by a stride (the factor is commented out). That implies a
    // unit stride along this dimension; confirm against the host code.
    const int out_mask_offset =
        !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];

    int mat_mask_offset =
        !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
    int vec_mask_offset = 0;
    // Mask offset increments applied after each blockM-high step along in_vec.
    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
    const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];

    T out_scale{1};

    // Check output mask
    if (has_output_mask) {
      auto mask_out = out_mask[out_mask_offset];

      // Write zeros and return if mask is 0
      if (!mask_out) {
        if (cm == 0 && out_col < out_vec_size) {
          if (out_col + TN <= out_vec_size) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tn = 0; tn < TN; tn++) {
              out_vec[out_col + tn] = T(0.);
            }
          } else {
            for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
              out_vec[out_col + tn] = T(0.);
            }
          }
        }

        return;
      }

      // Store scalar if multiplicative mask
      if (has_mul_output_mask) {
        out_scale = T(mask_out);
      }
    }

    // Prepare for loop
    constexpr const uniform loop_stride = make_uniform(blockM);
    const uniform in_size = make_uniform(in_vec_size);
    const uniform n_iter = in_size / loop_stride;
    const uniform last_iter = loop_stride * n_iter;
    const uniform leftover = in_size - last_iter;

    // Edgecase handling
    if (out_col < out_vec_size) {
      // Shift the tail thread inwards so its TN columns stay in bounds.
      out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;

      // Per thread accumulation main loop
      for (int i = 0; i < n_iter; ++i) {
        // Adding a threadgroup_barrier improves performance slightly
        // This is possibly because it may help exploit cache better
        // NOTE(review): this barrier sits inside `if (out_col < out_vec_size)`
        // and after the output-mask early return. Threads whose columns fall
        // past the end of the output therefore never reach it. The Metal spec
        // expects every thread of the threadgroup to execute a barrier in a
        // conditional if any does; confirm this is safe.
        threadgroup_barrier(mem_flags::mem_none);

        if (!has_operand_mask ||
            (bool(mat_mask[mat_mask_offset]) &&
             bool(vec_mask[vec_mask_offset]))) {
          T block_scale{1};
          if (has_mul_operand_mask) {
            block_scale =
                T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tm = 0; tm < TM; tm++) {
            v_coeff[tm] = static_cast(in_vec[bm + tm]);
          }

          // Apply scale
          if (has_mul_operand_mask) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tm = 0; tm < TM; tm++) {
              v_coeff[tm] *= block_scale;
            }
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tm = 0; tm < TM; tm++) {
            for (int tn = 0; tn < TN; tn++) {
              inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
            }
            for (int tn = 0; tn < TN; tn++) {
              result[tn] += v_coeff[tm] * inter[tn];
            }
          }
        }

        bm += blockM;
        mat_mask_offset += mat_mask_step;
        vec_mask_offset += vec_mask_step;
      }

      // Remaining (fewer than blockM) rows; the tm loop stops at in_vec_size.
      if (leftover > 0) {
        if (!has_operand_mask ||
            (bool(mat_mask[mat_mask_offset]) &&
             bool(vec_mask[vec_mask_offset]))) {
          T block_scale{1};
          if (has_mul_operand_mask) {
            block_scale =
                T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
          }

          for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
            v_coeff[tm] = static_cast(in_vec[bm + tm]);

            if (has_mul_operand_mask) {
              v_coeff[tm] *= block_scale;
            }

            MLX_MTL_PRAGMA_UNROLL
            for (int tn = 0; tn < TN; tn++) {
              inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
            }

            MLX_MTL_PRAGMA_UNROLL
            for (int tn = 0; tn < TN; tn++) {
              result[tn] += v_coeff[tm] * inter[tn];
            }
          }
        }
      }
    }

    // Apply out scale
    if (has_mul_output_mask) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        result[tn] *= out_scale;
      }
    }

    // Simdgroup accumulations
    MLX_MTL_PRAGMA_UNROLL
    for (int tn = 0; tn < TN; tn++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
        result[tn] += simd_shuffle_down(result[tn], SN * sm);
      }
    }

    // Threadgroup accumulation results
    // NOTE(review): as in GEMVKernel, the barrier below runs only for
    // thrM == 0 and uses mem_flags::mem_none for a threadgroup-memory
    // exchange. All GEMVTKernel instantiations in gemv_masked.metal use
    // BM == 1, so this branch is currently dead; revisit before enabling
    // BM > 1.
    if (needs_tgp_reduction) {
      threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
      if (thrM == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          tgp_results[tn] = result[tn];
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (sgM == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int sgm = 1; sgm < BM; sgm++) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tn = 0; tn < TN; tn++) {
              result[tn] += tgp_results[sgm * (blockN + TN) + tn];
            }
          }
        }
      }
    }

    // Threadgroup accumulation and writing out results
    if (cm == 0 && out_col < out_vec_size) {
      MLX_MTL_PRAGMA_UNROLL
      for (int j = 0; j < TN; j++) {
        out_vec[out_col + j] = static_cast(result[j]);
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

// Kernel entry point for the masked matrix-vector product. tid.z selects the
// batch element. The vector, matrix and mask pointers are advanced to that
// element, via elem_to_loc when kDoNCBatch is set and a single leading
// stride otherwise. The work is then delegated to GEMVKernel::run.
// mask_batch_strides holds batch_ndim strides per mask, in order: output
// mask (if present), matrix mask, vector mask.
template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_masked(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* vector_batch_stride [[buffer(11)]],
    const constant int64_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant int64_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVKernel;
  threadgroup float tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  constexpr bool has_operand_mask = !metal::is_same_v;
  constexpr bool has_output_mask = !metal::is_same_v;

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      // Skip past the output-mask batch strides.
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant auto* mask_strides_mat = mask_batch_strides;
      const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      out_vec,
      in_vec_size,
      out_vec_size,
      marix_ld,
      out_mask,
      mat_mask,
      vec_mask,
      mask_strides,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

// Kernel entry point for the masked vector-matrix product. Batch handling
// is identical to gemv_masked; the work is delegated to GEMVTKernel::run.
template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_masked(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* vector_batch_stride [[buffer(11)]],
    const constant int64_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant int64_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVTKernel;
  threadgroup float tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  constexpr bool has_operand_mask = !metal::is_same_v;
  constexpr bool has_output_mask = !metal::is_same_v;

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      // Skip past the output-mask batch strides.
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant auto* mask_strides_mat = mask_batch_strides;
      const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      out_vec,
      in_vec_size,
      out_vec_size,
      marix_ld,
      out_mask,
      mat_mask,
      vec_mask,
      mask_strides,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}
================================================
FILE: mlx/backend/metal/kernels/gemv_masked.metal
================================================
// Copyright © 2023-2024 Apple Inc.
// clang-format off
// NOTE(review): the targets of the next two #include directives (system
// headers) appear to be missing in this copy; restore them from upstream.
#include
#include

#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/gemv_masked.h"

// Instantiates gemv_masked for one combination of:
//   (output-mask type, operand-mask type, element type, tile shape, nc).
// The host-visible kernel name is built by stringizing the arguments, e.g.
//   "gemv_outmask_bool__opmask_nomask_float32_bm2_bn1_sm4_sn8_tm1_tn4_nc0".
// It must therefore match the host-side name construction exactly.
#define instantiate_gemv_helper( \
    outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_kernel( \
    "gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
    "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
    "_tn" #tn "_nc" #nc, \
  gemv_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)

// All mask-type pairings except (nomask, nomask). Each mask is bool_ (bool),
// the element type itself (name / itype), or nomask (nomask_t).
#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(bool_, bool, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(name, itype, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)

// nc is the kernel's kDoNCBatch flag: 0 uses one batch stride, 1 uses
// elem_to_loc over the batch shape.
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
  instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
  instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1)

// Tile configurations (bm, bn, sm, sn, tm, tn) compiled for each dtype.
// All use bn == 1, so GEMVKernel's threadgroup reduction is never compiled in.
#define instantiate_gemv_blocks(name, itype) \
  instantiate_gemv(name, itype, 2, 1, 4, 8, 1, 4) \
  instantiate_gemv(name, itype, 2, 1, 4, 8, 4, 4) \
  instantiate_gemv(name, itype, 2, 1, 2, 16, 1, 4) \
  instantiate_gemv(name, itype, 2, 1, 2, 16, 4, 4) \
  instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4)

instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);

// Same as instantiate_gemv_helper, for the vector-matrix kernel
// gemv_t_masked (names are prefixed "gemv_t_outmask_").
#define instantiate_gemv_t_helper( \
    outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_kernel( \
    "gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
    "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
    "_tn" #tn "_nc" #nc, \
  gemv_t_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)

// All mask-type pairings except (nomask, nomask).
#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(bool_, bool, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(name, itype, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)

// Both batch variants (nc = 0 and nc = 1).
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
  instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
  instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1)

// Tile configurations for gemv_t. All use bm == 1, so GEMVTKernel's
// threadgroup reduction is never compiled in.
#define instantiate_gemv_t_blocks(name, itype) \
  instantiate_gemv_t(name, itype, 1, 1, 8, 4, 4, 1) \
  instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \
  instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 1) \
  instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 4) \
  instantiate_gemv_t(name, itype, 1, 2, 8, 4, 8, 4) \
  instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4)

instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
// clang-format on
================================================
FILE: mlx/backend/metal/kernels/hadamard.h
================================================
// Copyright © 2024 Apple Inc.
#include #include #include "mlx/backend/metal/kernels/steel/defines.h" using namespace metal; // Thread local Hadamard transform for 2^R template METAL_FUNC void radix_func(thread float* x) { constexpr short logR = __builtin_ctz(R); short h = 1; STEEL_PRAGMA_UNROLL for (short s = 0; s < logR; s++) { STEEL_PRAGMA_UNROLL for (short i = 0; i < R / 2; i++) { short k = i & (h - 1); short j = ((i - k) << 1) + k; float a = x[j]; float b = x[j + h]; x[j] = a + b; x[j + h] = a - b; } h <<= 1; } } template [[kernel]] void hadamard_n( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], constant const float& scale, uint3 elem [[thread_position_in_grid]], uint3 grid [[threads_per_grid]]) { // Compute a Hadamard transform of size N = 2^k // // Equivalent to: // from scipy.linalg import hadamard // y = hadamard(len(x)) @ x constexpr short num_threads = N / max_radix; constexpr short logN = __builtin_ctz(N); constexpr short logR = __builtin_ctz(max_radix); constexpr short num_steps = logN / logR; constexpr short logFinal = logN % logR; constexpr short final_radix = 1 << (logFinal); int batch_idx = elem.y * N * stride + elem.z; short i = elem.x; threadgroup T buf[N]; // Read values from device if (stride == 1) { STEEL_PRAGMA_UNROLL for (short j = 0; j < max_radix / read_width; j++) { short index = j * read_width * num_threads + i * read_width; STEEL_PRAGMA_UNROLL for (short r = 0; r < read_width; r++) { buf[index + r] = in[batch_idx + index + r]; } } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < max_radix; j++) { buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; } } threadgroup_barrier(mem_flags::mem_threadgroup); float x[max_radix]; short h = 1; STEEL_PRAGMA_UNROLL for (short s = 0; s < num_steps; s++) { short k = i & (h - 1); short j = ((i - k) << logR) + k; STEEL_PRAGMA_UNROLL for (short r = 0; r < max_radix; r++) { x[r] = buf[j + h * r]; } radix_func(x); STEEL_PRAGMA_UNROLL for (short r = 0; r < max_radix; r++) { buf[j + h * r] = 
T(x[r]); } h <<= logR; threadgroup_barrier(mem_flags::mem_threadgroup); } // Do the final radix // e.g. max_radix = 16 // N = 1024 = 16 * 16 * 4 if (final_radix > 1) { // Each thread does multiple butterflies STEEL_PRAGMA_UNROLL for (int t = 0; t < max_radix / final_radix; t++) { short index = i + t * num_threads; short k = index & (h - 1); short j = ((index - k) << logFinal) + k; STEEL_PRAGMA_UNROLL for (short r = 0; r < final_radix; r++) { x[r] = buf[j + h * r]; } radix_func(x); STEEL_PRAGMA_UNROLL for (short r = 0; r < final_radix; r++) { buf[j + h * r] = T(x[r]); } } threadgroup_barrier(mem_flags::mem_threadgroup); } // Write values to device if (stride == 1) { STEEL_PRAGMA_UNROLL for (short j = 0; j < max_radix / read_width; j++) { short index = j * read_width * num_threads + i * read_width; STEEL_PRAGMA_UNROLL for (short r = 0; r < read_width; r++) { out[batch_idx + index + r] = T(buf[index + r] * scale); } } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < max_radix; j++) { out[batch_idx + (j * num_threads + i) * stride] = buf[j * num_threads + i]; } } } template [[kernel]] void hadamard_m( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], constant const float& scale, uint3 elem [[thread_position_in_grid]], uint3 grid [[threads_per_grid]]) { // Compute a Hadamard transform of size M // using a naive O(M^2) codelet. // // This kernel is the second stage in the computation // of a Hadamard transform of size M*N where N = 2^k. 
int index = elem.x * grid.y + elem.y; short i = index % (N / read_width); int batch_idx = index / (N / read_width) * M * N; float x[read_width][M]; STEEL_PRAGMA_UNROLL for (short c = 0; c < M; c++) { STEEL_PRAGMA_UNROLL for (short r = 0; r < read_width; r++) { x[r][c] = in[batch_idx + c * N + i * read_width + r]; } } STEEL_PRAGMA_UNROLL for (short r = 0; r < read_width; r++) { // This function is JIT compiled for M // using the Hadamard matrix strings in `metal/hadamard.cpp` hadamard_radix_m(x[r]); } // Write back to device STEEL_PRAGMA_UNROLL for (short c = 0; c < M; c++) { STEEL_PRAGMA_UNROLL for (short r = 0; r < read_width; r++) { out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale); } } } ================================================ FILE: mlx/backend/metal/kernels/indexing/gather.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include "mlx/backend/metal/kernels/indexing/indexing.h" template METAL_FUNC void gather_impl( const device T* src [[buffer(0)]], device T* out [[buffer(1)]], const constant int* src_shape [[buffer(2)]], const constant int64_t* src_strides [[buffer(3)]], const constant size_t& src_ndim [[buffer(4)]], const constant int* slice_sizes [[buffer(5)]], const constant int* axes [[buffer(6)]], const thread Indices& indices, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { LocT src_idx = 0; for (int i = 0; i < NIDX; ++i) { LocT idx_loc; if (IDX_NDIM == 0) { idx_loc = 0; } else if (IDX_NDIM == 1) { idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); } else { idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); idx_loc += indices.row_contiguous[i] ? 
index.y : elem_to_loc( index.y, &indices.shapes[indices.ndim * i + 1], &indices.strides[indices.ndim * i + 1], indices.ndim - 1); } auto ax = axes[i]; auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); src_idx += static_cast(idx_val) * static_cast(src_strides[ax]); } auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); LocT out_idx = index.z; if (IDX_NDIM == 1) { out_idx += static_cast(grid_dim.z) * index.x; } else if (IDX_NDIM >= 2) { out_idx += grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); } out[out_idx] = src[src_offset + src_idx]; } ================================================ FILE: mlx/backend/metal/kernels/indexing/gather_axis.h ================================================ // Copyright © 2025 Apple Inc. #pragma once template [[kernel]] void gather_axis( const device T* src [[buffer(0)]], const device IdxT* indices [[buffer(1)]], device T* out [[buffer(2)]], const constant int* shape [[buffer(3)]], const constant int64_t* src_strides [[buffer(4)]], const constant int64_t* idx_strides [[buffer(5)]], const constant size_t& ndim [[buffer(6)]], const constant int& axis [[buffer(7)]], const constant int& axis_size [[buffer(8)]], const constant size_t& src_ax_stride [[buffer(9)]], const constant size_t& idx_ax_stride [[buffer(10)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { LocT elem_idx = index.z * static_cast(grid_dim.x); LocT out_idx = elem_idx * grid_dim.y + index.x; LocT idx_loc = index.y * static_cast(idx_ax_stride); if (IdxC) { idx_loc += out_idx; } else { idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); } auto idx_val = indices[idx_loc]; if (is_signed_v) { idx_val = (idx_val < 0) ? 
idx_val + axis_size : idx_val; } LocT src_idx = idx_val * static_cast(src_ax_stride); if (SrcC) { src_idx += elem_idx * axis_size + index.x; } else { src_idx += elem_to_loc(elem_idx + index.x, shape, src_strides, ndim); } out_idx += index.y * static_cast(grid_dim.x); out[out_idx] = src[src_idx]; } ================================================ FILE: mlx/backend/metal/kernels/indexing/gather_front.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include "mlx/backend/metal/kernels/indexing/indexing.h" template [[kernel]] void gather_front( const device T* src, const device IdxT* indices, device T* out, const constant int64_t& stride, const constant int& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { auto idx = offset_neg_idx(indices[index.y], size); LocT src_idx = static_cast(stride) * idx; LocT out_idx = static_cast(stride) * index.y; int s_idx = N * index.x; for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) { out[out_idx + s_idx] = src[src_idx + s_idx]; } } ================================================ FILE: mlx/backend/metal/kernels/indexing/indexing.h ================================================ // Copyright © 2023-2024 Apple Inc. #pragma once #include template struct Indices { const array buffers; const constant int* shapes; const constant int64_t* strides; const constant bool* row_contiguous; const int ndim; }; template METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { if (is_unsigned_v) { return idx; } else { return (idx < 0) ? idx + size : idx; } } ================================================ FILE: mlx/backend/metal/kernels/indexing/masked_scatter.h ================================================ // Copyright © 2025 Apple Inc. 
#pragma once

constant mlx::os_log logger("mlx", "masked_assign");

// Masked assignment kernel, one thread per output position `idx`.
//
// Positions whose `mask` entry is false are left untouched. For a masked-in
// position, `scatter_offsets[idx]` selects an element inside one batch of
// `src` (each batch spans `src_batch_size` elements), and the batch number is
// `idx / mask_batch_size`. That element is copied into `out[idx]`. An offset
// at or past `src_batch_size` is reported through `logger` and nothing is
// written. When `src_contiguous` is false, the flat source position is mapped
// to memory through `src_shapes` / `src_strides` / `src_ndim`.
//
// NOTE(review): the `template` parameter list appears to have been stripped
// in this dump. The body relies on `T` and `src_contiguous` being template
// parameters, so restore them from upstream before compiling.
template [[kernel]] void masked_assign_impl(
    const device bool* mask [[buffer(0)]],
    const device uint* scatter_offsets [[buffer(1)]],
    const device T* src [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int* src_shapes [[buffer(4)]],
    const constant int64_t* src_strides [[buffer(5)]],
    const constant int& src_ndim [[buffer(6)]],
    const constant int64_t& src_batch_size [[buffer(7)]],
    const constant int64_t& mask_batch_size [[buffer(8)]],
    uint idx [[thread_position_in_grid]]) {
  // Masked-out positions keep whatever `out` already holds.
  if (mask[idx]) {
    const uint offset_in_batch = scatter_offsets[idx];
    if (offset_in_batch < src_batch_size) {
      const uint batch = idx / mask_batch_size;
      // Flat position of the selected element across all batches of `src`.
      const auto flat_pos = batch * src_batch_size + offset_in_batch;
      out[idx] = src_contiguous
          ? src[flat_pos]
          : src[elem_to_loc(flat_pos, src_shapes, src_strides, src_ndim)];
    } else {
      // Refuse to read past the end of the current source batch.
      logger.log_debug("Out of bound read from src");
    }
  }
}

================================================
FILE: mlx/backend/metal/kernels/indexing/scatter.h
================================================
// Copyright © 2024 Apple Inc.
#pragma once

#include "mlx/backend/metal/kernels/indexing/indexing.h"

// NOTE(review): angle-bracket template arguments (on `mlx_atomic`, `Indices`,
// `static_cast`, ...) appear to have been stripped from this dump. Restore
// them from upstream before compiling.

// Scatter helper. Each thread handles up to NWORK consecutive index positions,
// starting at gid.y * NWORK and stopping at idx_size, for one element gid.x of
// the per-index update slice.
//
// For every index position it does the following:
//   - Builds the output offset from the NIDX index arrays. Negative indices
//     are wrapped by offset_neg_idx against out_shape[axes[i]].
//   - Applies `Op` atomically to `out` with the matching element of
//     `updates`.
template <
    typename T,
    typename IdxT,
    typename Op,
    int NIDX,
    bool UPD_ROW_CONTIG,
    int NWORK,
    typename LocT>
METAL_FUNC void scatter_impl(
    const device T* updates,
    device mlx_atomic* out,
    const constant int* upd_shape,
    const constant int64_t* upd_strides,
    const constant size_t& upd_ndim,
    const constant size_t& upd_size,
    const constant int* out_shape,
    const constant int64_t* out_strides,
    const constant size_t& out_ndim,
    const constant int* axes,
    const constant size_t& idx_size,
    const thread Indices& indices,
    uint2 gid [[thread_position_in_grid]]) {
  Op op;
  auto ind_idx = gid.y * NWORK;
  // Offset of slice element gid.x inside the output. It is only computed when
  // each update slice has more than one element.
  LocT out_offset = 0;
  if (upd_size > 1) {
    out_offset = elem_to_loc(
        gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
  }

  for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
    LocT out_idx = out_offset;
    for (int i = 0; i < NIDX; ++i) {
      auto idx_loc = indices.row_contiguous[i]
          ? ind_idx
          : elem_to_loc(
                ind_idx,
                &indices.shapes[indices.ndim * i],
                &indices.strides[indices.ndim * i],
                indices.ndim);
      auto ax = axes[i];
      auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
      out_idx += static_cast(idx_val) * static_cast(out_strides[ax]);
    }
    auto upd_idx = ind_idx * static_cast(upd_size) + gid.x;
    if constexpr (!UPD_ROW_CONTIG) {
      upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
    }
    op.atomic_update(out, updates[upd_idx], out_idx);
  }
}

// Applies `out[o] = op(out[o], updates[u])` over a slice of `out` (shifted by
// `output_offset`), with NWORK elements per thread.
//
// Grid layout, as used by the NDIM == 1/2/3 branches below:
//   - gid.x selects an NWORK-wide chunk of the last update dimension.
//   - gid.y selects the second-to-last dimension.
//   - gid.z covers the remaining leading dimensions.
//
// The row-major linear element index is therefore
//   ((gid.z * gsize.y + gid.y) * gsize.x + gid.x) * NWORK.
// This assumes the last dimension spans exactly gsize.x * NWORK elements; the
// inner loops step NWORK times along it with no bounds check.
// TODO(review): confirm against the host-side dispatch.
//
// NOTE(review): `update_size` is not read by this kernel.
template <
    typename T,
    typename IdxT,
    typename Op,
    bool OUT_ROW_CONTIG,
    bool UPD_ROW_CONTIG,
    bool UPD_SCALAR,
    int NWORK,
    int NDIM>
[[kernel]] void slice_update_op_impl(
    const device T* updates [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant int* update_shape [[buffer(2)]],
    const constant int64_t* update_strides [[buffer(3)]],
    const constant int& update_ndim [[buffer(4)]],
    const constant int64_t& update_size [[buffer(5)]],
    const constant int64_t* output_strides [[buffer(6)]],
    const constant int64_t& output_offset [[buffer(7)]],
    uint3 gid [[thread_position_in_grid]],
    uint3 gsize [[threads_per_grid]]) {
  Op op;

  // Linear index of this thread's first element. The z and y terms must be
  // scaled by the full row size (gsize.x * NWORK). The previous form
  // `gid.z * gsize.y + gid.y * gsize.x + gid.x * NWORK` made neighbouring
  // rows/planes overlap whenever NWORK > 1 or gsize.z > 1.
  IdxT idx = ((IdxT(gid.z) * gsize.y + gid.y) * gsize.x + gid.x) * NWORK;
  IdxT out_idx;
  IdxT update_idx;

  // Location of the first element in `out`.
  if constexpr (OUT_ROW_CONTIG) {
    out_idx = idx;
  } else if constexpr (NDIM == 1) {
    out_idx = NWORK * gid.x * output_strides[0];
  } else if constexpr (NDIM == 2) {
    out_idx = gid.y * output_strides[0] + NWORK * gid.x * output_strides[1];
  } else if constexpr (NDIM == 3) {
    out_idx = gid.z * output_strides[0] + gid.y * output_strides[1] +
        NWORK * gid.x * output_strides[2];
  } else {
    out_idx = elem_to_loc(idx, update_shape, output_strides, update_ndim);
  }

  // Location of the first element in `updates`. A scalar update is always
  // read from position 0.
  if constexpr (UPD_SCALAR) {
    update_idx = 0;
  } else if constexpr (UPD_ROW_CONTIG) {
    update_idx = idx;
  } else if constexpr (NDIM == 1) {
    update_idx = NWORK * gid.x * update_strides[0];
  } else if constexpr (NDIM == 2) {
    update_idx =
        gid.y * update_strides[0] + NWORK * gid.x * update_strides[1];
  } else if constexpr (NDIM == 3) {
    update_idx = gid.z * update_strides[0] + gid.y * update_strides[1] +
        NWORK * gid.x * update_strides[2];
  } else {
    update_idx = elem_to_loc(idx, update_shape, update_strides, update_ndim);
  }

  out += output_offset;

  if constexpr (OUT_ROW_CONTIG && (UPD_ROW_CONTIG || UPD_SCALAR)) {
    // Both sides are dense: walk NWORK consecutive elements.
    for (int j = 0; j < NWORK; j++) {
      out[out_idx] = op(out[out_idx], updates[update_idx]);
      out_idx++;
      if constexpr (!UPD_SCALAR) {
        update_idx++;
      }
    }
  } else {
    // Otherwise step along the last dimension using its strides.
    auto out_stride = output_strides[update_ndim - 1];
    auto update_stride = update_strides[update_ndim - 1];
    for (int j = 0; j < NWORK; j++) {
      out[out_idx] = op(out[out_idx], updates[update_idx]);
      out_idx += out_stride;
      if constexpr (!UPD_SCALAR) {
        update_idx += update_stride;
      }
    }
  }
}

================================================
FILE: mlx/backend/metal/kernels/indexing/scatter_axis.h
================================================
// Copyright © 2025 Apple Inc.
#pragma once

// Axis scatter. The thread at `index` reads one index value from `indices` and
// one update value from `upd` at the same logical position, then atomically
// combines the update into `out` with `Op`.
//
// Decomposition used below:
//   - index.y walks the scatter axis, scaled by idx_ax_stride / upd_ax_stride.
//   - index.x walks the grid_dim.x positions after the axis.
//   - index.z walks the positions before the axis.
//
// `out` is addressed row-major with `out_axis_size` entries along the axis:
//   (index.z * out_axis_size + idx_val) * grid_dim.x + index.x.
// When IdxT is signed, negative indices are wrapped by adding `out_axis_size`.
// When UpdC / IdxC are false, positions are mapped through `shape` and the
// corresponding strides.
//
// NOTE(review): angle-bracket template arguments (on `mlx_atomic`,
// `static_cast`, `is_signed_v`) appear to have been stripped from this dump.
// `axis` is not read by this kernel.
template <
    typename T,
    typename IdxT,
    typename LocT,
    typename Op,
    bool UpdC,
    bool IdxC>
[[kernel]] void scatter_axis(
    const device T* upd [[buffer(0)]],
    const device IdxT* indices [[buffer(1)]],
    device mlx_atomic* out [[buffer(2)]],
    const constant int* shape [[buffer(3)]],
    const constant int64_t* upd_strides [[buffer(4)]],
    const constant int64_t* idx_strides [[buffer(5)]],
    const constant size_t& ndim [[buffer(6)]],
    const constant int& axis [[buffer(7)]],
    const constant int& out_axis_size [[buffer(8)]],
    const constant size_t& upd_ax_stride [[buffer(9)]],
    const constant size_t& idx_ax_stride [[buffer(10)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  Op op;

  LocT elem_idx = index.z * static_cast(grid_dim.x);

  // Location of the index value.
  LocT idx_loc = index.y * static_cast(idx_ax_stride);
  if (IdxC) {
    idx_loc += elem_idx * grid_dim.y + index.x;
  } else {
    idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim);
  }

  auto idx_val = indices[idx_loc];
  if (is_signed_v) {
    idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val;
  }

  // Location of the update value.
  LocT upd_idx = index.y * static_cast(upd_ax_stride);
  if (UpdC) {
    upd_idx += elem_idx * grid_dim.y + index.x;
  } else {
    upd_idx += elem_to_loc(elem_idx + index.x, shape, upd_strides, ndim);
  }

  // Widen idx_val to LocT before multiplying. `idx_val * grid_dim.x` on a
  // 32-bit index type is evaluated in 32 bits and can wrap before it is added
  // to a 64-bit LocT.
  LocT out_idx = elem_idx * static_cast(out_axis_size) +
      LocT(idx_val) * grid_dim.x + index.x;
  op.atomic_update(out, upd[upd_idx], out_idx);
}

================================================
FILE: mlx/backend/metal/kernels/layer_norm.metal
================================================
// Copyright © 2024 Apple Inc.
#include #include #include "mlx/backend/metal/kernels/utils.h" using namespace metal; constant bool has_w [[function_constant(20)]]; template inline void initialize_buffer( threadgroup float* xs, uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { if (simd_group_id == 0) { for (int i = 0; i < N; i++) { xs[N * simd_lane_id + i] = 0; } } threadgroup_barrier(mem_flags::mem_threadgroup); } template inline void threadgroup_sum( thread float* x, threadgroup float* xs, uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { for (int i = 0; i < N; i++) { x[i] = simd_sum(x[i]); } threadgroup_barrier(mem_flags::mem_threadgroup); if (simd_lane_id == 0) { for (int i = 0; i < N; i++) { xs[N * simd_group_id + i] = x[i]; } } threadgroup_barrier(mem_flags::mem_threadgroup); for (int i = 0; i < N; i++) { x[i] = xs[N * simd_lane_id + i]; x[i] = simd_sum(x[i]); } } template [[kernel]] void layer_norm_single_row( const device T* x, const device T* w, const device T* b, device T* out, constant float& eps, constant uint& axis_size, constant uint& w_stride, constant uint& b_stride, uint gid [[threadgroup_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { constexpr int SIMD_SIZE = 32; // Initialize the registers and threadgroup memory float thread_x[N_READS] = {0}; threadgroup float local_buffer[SIMD_SIZE] = {0}; initialize_buffer(local_buffer, simd_lane_id, simd_group_id); // Advance the pointers x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; out += gid * size_t(axis_size) + lid * N_READS; // Compute some variables for reading writing etc const bool safe = lid * N_READS + N_READS <= axis_size; const int n = axis_size - lid * N_READS; // Read the inputs if (safe) { for (int i = 0; i < N_READS; i++) { 
thread_x[i] = x[i]; } } else { for (int i = 0; i < n; i++) { thread_x[i] = x[i]; } } // Compute the mean float mean = 0; for (int i = 0; i < N_READS; i++) { mean += thread_x[i]; } threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); mean /= axis_size; // Compute the normalizer float normalizer = 0; if (!safe) { for (int i = n; i < N_READS; i++) { thread_x[i] = mean; } } for (int i = 0; i < N_READS; i++) { thread_x[i] -= mean; normalizer += thread_x[i] * thread_x[i]; } threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] *= normalizer; out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } else { for (int i = 0; i < n; i++) { thread_x[i] *= normalizer; out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } } template [[kernel]] void layer_norm_looped( const device T* x, const device T* w, const device T* b, device T* out, constant float& eps, constant uint& axis_size, constant uint& w_stride, constant uint& b_stride, uint gid [[threadgroup_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { constexpr int SIMD_SIZE = 32; threadgroup float local_buffer[SIMD_SIZE]; initialize_buffer(local_buffer, simd_lane_id, simd_group_id); x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; // Compute the mean float mean = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { mean += x[i + r]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { mean += x[i + r]; } } } } threadgroup_sum(&mean, local_buffer, simd_lane_id, 
simd_group_id); mean /= axis_size; // Compute the normalizer float normalizer = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { float t = x[i + r] - mean; normalizer += t * t; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { float t = x[i + r] - mean; normalizer += t * t; } } } } threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs out += gid * size_t(axis_size) + lid * N_READS; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { float xi = (x[r + i] - mean) * normalizer; out[r + i] = w[w_stride * (i + r)] * static_cast(xi) + b[b_stride * (i + r)]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { float xi = (x[r + i] - mean) * normalizer; out[r + i] = w[w_stride * (i + r)] * static_cast(xi) + b[b_stride * (i + r)]; } } } } } template [[kernel]] void vjp_layer_norm_single_row( const device T* x, const device T* w, const device T* g, device T* gx, device T* gw, constant float& eps, constant uint& axis_size, constant uint& w_stride, uint gid [[threadgroup_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { constexpr int SIMD_SIZE = 32; // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; // Initialize the registers and threadgroup memory float thread_x[N_READS] = {0}; float thread_w[N_READS] = {0}; float thread_g[N_READS] = {0}; threadgroup float local_buffer[3 * SIMD_SIZE]; initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); // Compute some variables for reading writing etc 
const bool safe = lid * N_READS + N_READS <= axis_size; const int n = axis_size - lid * N_READS; // Read the inputs if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; thread_g[i] = g[i]; thread_w[i] = w[i * w_stride]; } } else { for (int i = 0; i < n; i++) { thread_x[i] = x[i]; thread_g[i] = g[i]; thread_w[i] = w[i * w_stride]; } } // Compute the mean float mean = 0; for (int i = 0; i < N_READS; i++) { mean += thread_x[i]; } threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); mean /= axis_size; // Compute the neccesary scaling factors using the mean if (!safe) { for (int i = n; i < N_READS; i++) { thread_x[i] = mean; } } float factors[3] = {0}; constexpr int meanwg = 0; constexpr int meanwgxc = 1; constexpr int normalizer2 = 2; for (int i = 0; i < N_READS; i++) { thread_x[i] -= mean; factors[meanwg] += thread_w[i] * thread_g[i]; factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i]; factors[normalizer2] += thread_x[i] * thread_x[i]; } threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); factors[meanwg] /= axis_size; factors[meanwgxc] /= axis_size; factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; gw += gid * size_t(axis_size) + lid * N_READS; if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] *= normalizer; gx[i] = static_cast( normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - thread_x[i] * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } else { for (int i = 0; i < n; i++) { thread_x[i] *= normalizer; gx[i] = static_cast( normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - thread_x[i] * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } } template [[kernel]] void vjp_layer_norm_looped( const device T* x, const 
device T* w, const device T* g, device T* gx, device T* gw, constant float& eps, constant uint& axis_size, constant uint& w_stride, uint gid [[threadgroup_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { constexpr int SIMD_SIZE = 32; // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; threadgroup float local_buffer[3 * SIMD_SIZE]; initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); // Compute the mean float mean = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { mean += x[i + r]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { mean += x[i + r]; } } } } threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); mean /= axis_size; // Compute the neccesary scaling factors using the mean float factors[3] = {0}; constexpr int meanwg = 0; constexpr int meanwgxc = 1; constexpr int normalizer2 = 2; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { float t = x[i + r] - mean; float wi = w[(i + r) * w_stride]; float gi = g[i + r]; float wg = wi * gi; factors[meanwg] += wg; factors[meanwgxc] += wg * t; factors[normalizer2] += t * t; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { float t = x[i + r] - mean; float wi = w[(i + r) * w_stride]; float gi = g[i + r]; float wg = wi * gi; factors[meanwg] += wg; factors[meanwgxc] += wg * t; factors[normalizer2] += t * t; } } } } threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); factors[meanwg] /= axis_size; factors[meanwgxc] /= axis_size; factors[normalizer2] = 1 / 
(factors[normalizer2] / axis_size + eps); float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; gw += gid * size_t(axis_size) + lid * N_READS; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { float xi = (x[i + r] - mean) * normalizer; float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( normalizer * (wi * gi - factors[meanwg]) - xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { float xi = (x[i + r] - mean) * normalizer; float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( normalizer * (wi * gi - factors[meanwg]) - xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } } } } } } // clang-format off #define instantiate_layer_norm(name, itype) \ instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \ instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \ instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \ instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype) instantiate_layer_norm(float32, float) instantiate_layer_norm(float16, half) instantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on ================================================ FILE: mlx/backend/metal/kernels/logging.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320) #include namespace mlx { using os_log = metal::os_log; } // namespace mlx #else namespace mlx { struct os_log { constexpr os_log(constant char*, constant char*) constant {} template void log_debug(constant char*, Args...) 
const {} template void log_debug(constant char*, Args...) const constant {} }; } // namespace mlx #endif ================================================ FILE: mlx/backend/metal/kernels/logsumexp.h ================================================ // Copyright © 2025 Apple Inc. template [[kernel]] void logsumexp( const device T* in, device T* out, constant int& axis_size, uint gid [[threadgroup_position_in_grid]], uint _lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { int lid = _lid; constexpr int SIMD_SIZE = 32; threadgroup AccT local_max[SIMD_SIZE]; threadgroup AccT local_normalizer[SIMD_SIZE]; AccT ld[N_READS]; in += gid * size_t(axis_size) + lid * N_READS; if (lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { ld[i] = AccT(in[i]); } } else { for (int i = 0; i < N_READS; i++) { ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min; } } if (simd_group_id == 0) { local_max[simd_lane_id] = Limits::min; local_normalizer[simd_lane_id] = 0; } threadgroup_barrier(mem_flags::mem_threadgroup); // Get the max AccT maxval = Limits::finite_min; for (int i = 0; i < N_READS; i++) { maxval = (maxval < ld[i]) ? 
ld[i] : maxval; } maxval = simd_max(maxval); if (simd_lane_id == 0) { local_max[simd_group_id] = maxval; } threadgroup_barrier(mem_flags::mem_threadgroup); if (simd_group_id == 0) { maxval = simd_max(local_max[simd_lane_id]); if (simd_lane_id == 0) { local_max[0] = maxval; } } threadgroup_barrier(mem_flags::mem_threadgroup); maxval = local_max[0]; // Compute exp(x_i - maxval) and store the partial sums in local_normalizer AccT normalizer = 0; for (int i = 0; i < N_READS; i++) { normalizer += fast::exp(ld[i] - maxval); } normalizer = simd_sum(normalizer); if (simd_lane_id == 0) { local_normalizer[simd_group_id] = normalizer; } threadgroup_barrier(mem_flags::mem_threadgroup); if (simd_group_id == 0) { normalizer = simd_sum(local_normalizer[simd_lane_id]); if (simd_lane_id == 0) { out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); } } } template [[kernel]] void logsumexp_looped( const device T* in, device T* out, constant int& axis_size, uint gid [[threadgroup_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { in += gid * size_t(axis_size); constexpr int SIMD_SIZE = 32; threadgroup AccT local_max[SIMD_SIZE]; threadgroup AccT local_normalizer[SIMD_SIZE]; // Get the max and the normalizer in one go AccT prevmax; AccT maxval = Limits::finite_min; AccT normalizer = 0; for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); r++) { int offset = r * lsize * N_READS + lid * N_READS; AccT vals[N_READS]; if (offset + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { vals[i] = AccT(in[offset + i]); } } else { for (int i = 0; i < N_READS; i++) { vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; for (int i = 0; i < N_READS; i++) { maxval = (maxval < vals[i]) ? 
vals[i] : maxval; } normalizer *= fast::exp(prevmax - maxval); for (int i = 0; i < N_READS; i++) { normalizer += fast::exp(vals[i] - maxval); } } prevmax = maxval; maxval = simd_max(maxval); normalizer *= fast::exp(prevmax - maxval); normalizer = simd_sum(normalizer); prevmax = maxval; if (simd_lane_id == 0) { local_max[simd_group_id] = maxval; } threadgroup_barrier(mem_flags::mem_threadgroup); maxval = simd_max(local_max[simd_lane_id]); normalizer *= fast::exp(prevmax - maxval); if (simd_lane_id == 0) { local_normalizer[simd_group_id] = normalizer; } threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = simd_sum(local_normalizer[simd_lane_id]); if (lid == 0) { out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); } } ================================================ FILE: mlx/backend/metal/kernels/logsumexp.metal ================================================ // Copyright © 2023-2024 Apple Inc. #include #include using namespace metal; // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/logsumexp.h" #define instantiate_logsumexp(name, itype) \ instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \ instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \ instantiate_logsumexp(float32, float) instantiate_logsumexp(float16, half) instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on ================================================ FILE: mlx/backend/metal/kernels/quantized.h ================================================ // Copyright © 2023-2024 Apple Inc. #include #include constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; constant bool align_K [[function_constant(202)]]; using namespace metal; #define MLX_MTL_CONST static constant constexpr const MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int QUAD_SIZE = 4; template inline constexpr short get_pack_factor() { return (bits == 3 || bits == 5) ? 
8 : (bits == 6 ? 4 : wsize / bits); } template inline constexpr short get_bytes_per_pack() { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); } template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || bits == 8, "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 4.0f; x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 64.0f; } } else if (bits == 3) { for (int i = 0; i < values_per_thread; i += 8) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + x[i + 6] + x[i + 7]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 8.0f; x_thread[i + 2] = x[i + 2] / 64.0f; x_thread[i + 3] = x[i + 3] / 2.0f; x_thread[i + 4] = x[i + 4] / 16.0f; x_thread[i + 5] = x[i + 5] / 128.0f; x_thread[i + 6] = x[i + 6] / 4.0f; x_thread[i + 7] = x[i + 7] / 32.0f; } } else if (bits == 4) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 16.0f; x_thread[i + 2] = x[i + 2] / 256.0f; x_thread[i + 3] = x[i + 3] / 4096.0f; } } else if (bits == 5) { for (int i = 0; i < values_per_thread; i += 8) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + x[i + 6] + x[i + 7]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 32.0f; x_thread[i + 2] = x[i + 2] / 4.0f; x_thread[i + 3] = x[i + 3] / 128.0f; x_thread[i + 4] = x[i + 4] / 16.0f; x_thread[i + 5] = x[i + 5] / 2.0f; x_thread[i + 6] = x[i + 6] / 64.0f; x_thread[i + 7] = x[i + 7] / 8.0f; } } else if (bits == 6) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 64.0f; x_thread[i + 2] = 
x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 4.0f; } } else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { sum += x[i]; x_thread[i] = x[i]; } } return sum; } template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || bits == 8, "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 4.0f; x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 64.0f; } } else if (bits == 3) { for (int i = 0; i < N; i += 8) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + x[i + 6] + x[i + 7]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 8.0f; x_thread[i + 2] = x[i + 2] / 64.0f; x_thread[i + 3] = x[i + 3] / 2.0f; x_thread[i + 4] = x[i + 4] / 16.0f; x_thread[i + 5] = x[i + 5] / 128.0f; x_thread[i + 6] = x[i + 6] / 4.0f; x_thread[i + 7] = x[i + 7] / 32.0f; } } else if (bits == 4) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 16.0f; x_thread[i + 2] = x[i + 2] / 256.0f; x_thread[i + 3] = x[i + 3] / 4096.0f; } } else if (bits == 5) { for (int i = 0; i < N; i += 8) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + x[i + 6] + x[i + 7]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 32.0f; x_thread[i + 2] = x[i + 2] / 4.0f; x_thread[i + 3] = x[i + 3] / 128.0f; x_thread[i + 4] = x[i + 4] / 16.0f; x_thread[i + 5] = x[i + 5] / 2.0f; x_thread[i + 6] = x[i + 6] / 64.0f; x_thread[i + 7] = x[i + 7] / 8.0f; } } else if (bits == 6) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 64.0f; x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 4.0f; } } else if (bits == 8) { for (int i = 0; i < N; 
i++) { sum += x[i]; x_thread[i] = x[i]; } } for (int i = N; i < values_per_thread; i++) { x_thread[i] = 0; } return sum; } template inline U qdot( const device uint8_t* w, const thread U* x_thread, U scale, U bias, U sum) { static_assert( bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || bits == 8, "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; if (bits == 2) { for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + x_thread[4 * i + 1] * (w[i] & 0x0c) + x_thread[4 * i + 2] * (w[i] & 0x30) + x_thread[4 * i + 3] * (w[i] & 0xc0)); } } else if (bits == 3) { for (int i = 0; i < (values_per_thread / 8); i++) { x_thread += 8 * i; w += 3 * i; accum += (w[0] & 0x07) * x_thread[0]; accum += (w[0] & 0x38) * x_thread[1]; accum += (w[0] & 0xc0) * x_thread[2]; accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); accum += (w[1] & 0x0e) * x_thread[3]; accum += (w[1] & 0x70) * x_thread[4]; accum += (w[1] & 0x80) * x_thread[5]; accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); accum += (w[2] & 0x1c) * x_thread[6]; accum += (w[2] & 0xe0) * x_thread[7]; } } else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * (ws[i] & 0x000f) + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + x_thread[4 * i + 3] * (ws[i] & 0xf000)); } } else if (bits == 5) { for (int i = 0; i < (values_per_thread / 8); i++) { x_thread += 8 * i; w += 5 * i; accum += (w[0] & 0x1f) * x_thread[0]; accum += (w[0] & 0xe0) * x_thread[1]; accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); accum += (w[1] & 0x7c) * x_thread[2]; accum += (w[1] & 0x80) * x_thread[3]; accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); accum += (w[2] & 0xf0) * x_thread[4]; accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); accum += (w[3] & 0x3e) * x_thread[5]; accum += (w[3] & 0xc0) * x_thread[6]; accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); accum 
+= (w[4] & 0xf8) * x_thread[7]; } } else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { x_thread += 4 * i; w += 3 * i; accum += (w[0] & 0x3f) * x_thread[0]; accum += (w[0] & 0xc0) * x_thread[1]; accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); accum += (w[1] & 0xf0) * x_thread[2]; accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); accum += (w[2] & 0xfc) * x_thread[3]; } } else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { accum += x_thread[i] * w[i]; } } return scale * accum + sum * bias; }

// Dot product of the first N packed quantized weights in `w` with the
// activations in `x_thread`, used for the ragged tail of a row. Returns
// scale * accum + sum * bias.
//
// The packed fields are masked but never shifted down (e.g. `w & 0x0c`), and
// fields straddling a byte boundary multiply x by 256, so the activations are
// presumably pre-divided by each field's bit position by the loader
// (load_vector_safe) -- NOTE(review): confirm against that helper.
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum,
    int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  } else if (bits == 3) {
    // 8 values per 3 bytes. Offsets are derived from `i` directly; advancing
    // the pointers cumulatively (`x_thread += 8 * i`) would be wrong from the
    // third iteration on.
    for (int i = 0; i < (N / 8); i++) {
      const thread U* xt = x_thread + 8 * i;
      const device uint8_t* wb = w + 3 * i;

      accum += (wb[0] & 0x07) * xt[0];
      accum += (wb[0] & 0x38) * xt[1];
      accum += (wb[0] & 0xc0) * xt[2];
      accum += (wb[1] & 0x01) * (xt[2] * 256.0f);

      accum += (wb[1] & 0x0e) * xt[3];
      accum += (wb[1] & 0x70) * xt[4];
      accum += (wb[1] & 0x80) * xt[5];
      accum += (wb[2] & 0x03) * (xt[5] * 256.0f);

      accum += (wb[2] & 0x1c) * xt[6];
      accum += (wb[2] & 0xe0) * xt[7];
    }
  } else if (bits == 4) {
    // Two bytes (four nibbles) per step.
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  } else if (bits == 5) {
    // 8 values per 5 bytes; non-cumulative offsets as in the 3-bit case.
    for (int i = 0; i < (N / 8); i++) {
      const thread U* xt = x_thread + 8 * i;
      const device uint8_t* wb = w + 5 * i;

      accum += (wb[0] & 0x1f) * xt[0];
      accum += (wb[0] & 0xe0) * xt[1];
      accum += (wb[1] & 0x3) * (xt[1] * 256.0f);
      accum += (wb[1] & 0x7c) * xt[2];
      accum += (wb[1] & 0x80) * xt[3];
      accum += (wb[2] & 0xf) * (xt[3] * 256.0f);
      accum += (wb[2] & 0xf0) * xt[4];
      accum += (wb[3] & 0x1) * (xt[4] * 256.0f);
      accum += (wb[3] & 0x3e) * xt[5];
      accum += (wb[3] & 0xc0) * xt[6];
      accum += (wb[4] & 0x7) * (xt[6] * 256.0f);
      accum += (wb[4] & 0xf8) * xt[7];
    }
  } else if (bits == 6) {
    // 4 values per 3 bytes; non-cumulative offsets as in the 3-bit case.
    for (int i = 0; i < (N / 4); i++) {
      const thread U* xt = x_thread + 4 * i;
      const device uint8_t* wb = w + 3 * i;

      accum += (wb[0] & 0x3f) * xt[0];
      accum += (wb[0] & 0xc0) * xt[1];
      accum += (wb[1] & 0x0f) * (xt[1] * 256.0f);
      accum += (wb[1] & 0xf0) * xt[2];
      accum += (wb[2] & 0x03) * (xt[2] * 256.0f);
      accum += (wb[2] & 0xfc) * xt[3];
    }
  } else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}

// Rank-1 update for one scalar activation `x`: for each of the
// `values_per_thread` packed weights q_j in `w`,
//   result[j] += x * (scale * q_j + bias).
// Fields are fully shifted down here, except in the 2/4-bit paths where the
// per-field shift is folded into the pre-divided scales `s[]`.
template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  } else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[8 * i] += x * ((w0 & 0x7) * scale + bias);
      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
      result[8 * i + 2] +=
          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
      result[8 * i + 5] +=
          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
    }
  } else if (bits == 4) {
    U s[2] = {scale, scale / 16.0f};
    for (int i = 0; i < (values_per_thread / 2); i++) {
      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
    }
  } else if (bits == 5) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      uint8_t w0 = w[5 * i];
      uint8_t w1 = w[5 * i + 1];
      uint8_t w2 = w[5 * i + 2];
      uint8_t w3 = w[5 * i + 3];
      uint8_t w4 = w[5 * i + 4];

      result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
      result[8 * i + 1] +=
          x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
      result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
      result[8 * i + 3] +=
          x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
      result[8 * i + 4] +=
          x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
      result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
      result[8 * i + 6] +=
          x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
      result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
    }
  } else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
      result[4 * i + 1] +=
          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
      result[4 * i + 2] +=
          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
    }
  } else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      result[i] += x * (scale * w[i] + bias);
    }
  }
}

// Dequantizes N packed weights from device memory into threadgroup memory:
// w_local[j] = scale * q_j + bias.
template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  if (bits == 2) {
    // Per-field scales absorb the un-shifted mask positions.
    U s[4] = {
        scale,
        scale / static_cast<U>(4.0f),
        scale / static_cast<U>(16.0f),
        scale / static_cast<U>(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  } else if (bits == 3) {
    // 8 values per 3 bytes. Offsets are derived from `i` directly instead of
    // advancing `w_local`/`w` cumulatively, which breaks after 2 iterations.
    for (int i = 0; i < (N / 8); i++) {
      threadgroup U* dst = w_local + 8 * i;
      const device uint8_t* wb = w + 3 * i;

      dst[0] = (wb[0] & 0x7) * scale + bias;
      dst[1] = ((wb[0] & 0x38) >> 3) * scale + bias;
      dst[2] = (((wb[0] & 0xc0) >> 6) + ((wb[1] & 0x1) << 2)) * scale + bias;
      dst[3] = ((wb[1] & 0xe) >> 1) * scale + bias;
      dst[4] = ((wb[1] & 0x70) >> 4) * scale + bias;
      dst[5] = (((wb[1] & 0x80) >> 7) + ((wb[2] & 0x3) << 1)) * scale + bias;
      dst[6] = ((wb[2] & 0x1c) >> 2) * scale + bias;
      dst[7] = ((wb[2] & 0xe0) >> 5) * scale + bias;
    }
  } else if (bits == 4) {
    U s[2] = {scale, scale / static_cast<U>(16.0f)};
    for (int i = 0; i < (N / 2); i++) {
      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
    }
  } else if (bits == 5) {
    // 8 values per 5 bytes; non-cumulative offsets as above.
    for (int i = 0; i < (N / 8); i++) {
      threadgroup U* dst = w_local + 8 * i;
      const device uint8_t* wb = w + 5 * i;

      dst[0] = (wb[0] & 0x1f) * scale + bias;
      dst[1] = (((wb[0] & 0xe0) >> 5) + ((wb[1] & 0x3) << 3)) * scale + bias;
      dst[2] = ((wb[1] & 0x7c) >> 2) * scale + bias;
      dst[3] = (((wb[1] & 0x80) >> 7) + ((wb[2] & 0xf) << 1)) * scale + bias;
      dst[4] = (((wb[2] & 0xf0) >> 4) + ((wb[3] & 0x1) << 4)) * scale + bias;
      dst[5] = ((wb[3] & 0x3e) >> 1) * scale + bias;
      dst[6] = (((wb[3] & 0xc0) >> 6) + ((wb[4] & 0x7) << 2)) * scale + bias;
      dst[7] = ((wb[4] & 0xf8) >> 3) * scale + bias;
    }
  } else if (bits == 6) {
    // 4 values per 3 bytes; non-cumulative offsets as above.
    for (int i = 0; i < (N / 4); i++) {
      threadgroup U* dst = w_local + 4 * i;
      const device uint8_t* wb = w + 3 * i;

      dst[0] = (wb[0] & 0x3f) * scale + bias;
      dst[1] = (((wb[0] >> 6) & 0x03) + ((wb[1] & 0x0f) << 2)) * scale + bias;
      dst[2] = (((wb[1] >> 4) & 0x0f) + ((wb[2] & 0x03) << 4)) * scale + bias;
      dst[3] = ((wb[2] >> 2) & 0x3f) * scale + bias;
    }
  } else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      w_local[i] = scale * w[i] + bias;
    }
  }
}

// Loads a BROWS x BCOLS tile of quantized weights, dequantizing it into
// threadgroup memory `dst` (leading dimension dst_ld). Each thread handles
// n_reads packs of one row. reduction_dim selects whether next() advances
// along the columns (1) or the rows (0) of the source.
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short group_size,
    short bits>
struct QuantizedBlockLoader {
  static_assert(
      BCOLS <= group_size,
      "The group size should be larger than the columns");
  static_assert(
      group_size % BCOLS == 0,
      "The group size should be divisible by the columns");
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  // Byte-granular packing: `pack_factor` values occupy `bytes_per_pack` bytes.
  MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits, 8>();
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
  MLX_MTL_CONST short n_reads =
      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
  // Number of tiles along the reduction dimension that share one scale/bias.
  MLX_MTL_CONST short group_steps = group_size / BCOLS;

  const int src_ld;
  const int tile_stride;
  short group_step_cnt;
  const int group_stride;

  const short thread_idx;
  const short bi;
  const short bj;

  threadgroup T* dst;
  const device uint8_t* src;
  const device T* scales;
  const device T* biases;

  QuantizedBlockLoader(
      const device uint8_t* src_,
      const device T* scales_,
      const device T* biases_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(
            reduction_dim ? BCOLS_PACKED * bytes_per_pack
                          : BROWS * src_ld * bytes_per_pack / pack_factor),
        group_step_cnt(0),
        group_stride(BROWS * src_ld / group_size),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        dst(dst_ + bi * dst_ld + bj * pack_factor),
        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
            bj * bytes_per_pack),
        scales(scales_ + bi * src_ld / group_size),
        biases(biases_ + bi * src_ld / group_size) {}

  // Loads the full tile without bounds checks.
  void load_unsafe() const {
    // Threads beyond the tile have nothing to load.
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
    }
  }

  // Loads the tile, zero-filling this thread's row when it lies outside
  // src_tile_dim (x = rows when reducing along columns, y otherwise).
  void load_safe(short2 src_tile_dim) const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    if (reduction_dim == 1 && bi >= src_tile_dim.x) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    if (reduction_dim == 0 && bi >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          (device uint8_t*)(src + i * bytes_per_pack),
          scale,
          bias,
          dst + i * pack_factor);
    }
  }

  // Advances to the next tile along the reduction dimension. Scales and
  // biases only move once every `group_steps` tiles when reducing along
  // columns.
  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      if (group_steps > 1) {
        group_step_cnt++;
        if (group_step_cnt == group_steps) {
          group_step_cnt = 0;
          scales++;
          biases++;
        }
      } else {
        scales++;
        biases++;
      }
    } else {
      scales += group_stride;
      biases += group_stride;
    }
  }
};

// Matrix-vector product for a short reduction dimension D: each quadgroup
// covers the D inputs and computes results_per_quadgroup output rows (spaced
// quads_per_simd apart), reducing across the quad with quad_sum.
//
// NOTE(review): pack_factor = 32 / bits is only exact for power-of-two bits;
// presumably this path is only dispatched for 2/4/8 bits -- confirm at the
// host-side dispatch.
template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = D / QUAD_SIZE;
  constexpr int packs_per_thread = values_per_thread / pack_factor;
  constexpr int scale_step_per_thread = group_size / values_per_thread;
  constexpr int results_per_quadgroup = 8;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_quadgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;

  w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  x += tid.x * in_vec_size + quad_lid * values_per_thread;
  y += tid.x * out_vec_size + out_row;

  U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

  for (int row = 0; row < results_per_quadgroup; row++) {
    // Only touch weights/scales/biases for rows inside the matrix; the
    // scale/bias loads used to happen before this check and could read past
    // the end of the buffers for the last tile.
    if (row * quads_per_simd + out_row < out_vec_size) {
      auto wl =
          (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
      const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
      const device T* bl = biases + row * in_vec_size_g * quads_per_simd;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }
  }

  for (int row = 0; row < results_per_quadgroup; row++) {
    result[row] = quad_sum(result[row]);
    if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
      y[row * quads_per_simd] = static_cast<T>(result[row]);
    }
  }
}

// Matrix-vector product without bounds checks: each simdgroup computes
// results_per_simdgroup consecutive output rows, stepping through the input
// in blocks of block_size. Accesses are only in bounds when in_vec_size is a
// multiple of block_size and out_vec_size a multiple of
// num_simdgroups * results_per_simdgroup.
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = get_pack_factor<bits, 32>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;

  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  x += tid.x * in_vec_size + simd_lid * values_per_thread;
  y += tid.x * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }

    ws += block_size * bytes_per_pack / pack_factor;
    scales += block_size / group_size;
    biases += block_size / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}

// Bounds-checked matrix-vector product. The final partial input block is
// handled with load_vector_safe/qdot_safe. Output rows are guarded per row
// when the matrix has fewer rows than one tile; otherwise the last tile is
// shifted back to stay in bounds and recomputes some rows.
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int packs_per_thread = 1;
  constexpr int pack_factor = get_pack_factor<bits, 32>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();

  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  if (out_row >= out_vec_size) {
    return;
  }

  // In this case we need to properly guard all our reads because there isn't
  // even 1 tile in the matrix
  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
    ws +=
        out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.x * in_vec_size + simd_lid * values_per_thread;
    y += tid.x * out_vec_size + out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0;
           row < results_per_simdgroup && out_row + row < out_vec_size;
           row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      ws += block_size * bytes_per_pack / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }

    // Number of valid inputs left for this lane in the final block.
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    if (remaining > 0) {
      U sum = load_vector_safe<T, U, values_per_thread, bits>(
          x, x_thread, remaining);

      for (int row = 0;
           row < results_per_simdgroup && out_row + row < out_vec_size;
           row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] += qdot_safe<U, values_per_thread, bits>(
            wl, x_thread, s, b, sum, remaining);
      }
    }

    for (int row = 0;
         row < results_per_simdgroup && out_row + row < out_vec_size;
         row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }

  // In this case the last tile is moved back to redo some output values
  else {
    ws += used_out_row * in_vec_size_w +
        simd_lid * packs_per_thread * bytes_per_pack;
    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.x * in_vec_size + simd_lid * values_per_thread;
    y += tid.x * out_vec_size + used_out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      ws += block_size * bytes_per_pack / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }

    // Number of valid inputs left for this lane in the final block.
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    if (remaining > 0) {
      U sum = load_vector_safe<T, U, values_per_thread, bits>(
          x, x_thread, remaining);

      for (int row = 0; row < results_per_simdgroup; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] += qdot_safe<U, values_per_thread, bits>(
            wl, x_thread, s, b, sum, remaining);
      }
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }
}

// Vector-matrix product (x^T W): each simdgroup computes tn * pack_factor
// consecutive outputs. Each lane takes one input row per step of block_size
// and accumulates with qouter; the partial sums are reduced with simd_sum.
template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const int in_vec_size,
    const int out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int num_simdgroups = 2;
  constexpr int pack_factor = get_pack_factor<bits, 32>();
  // Measured in W_T units: one uint32 per pack for power-of-two bits, bytes
  // otherwise.
  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 8>();

  constexpr int tn = 32 / pack_factor;
  constexpr int block_size = SIMD_SIZE;

  using W_T =
      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
  const device W_T* ws = (const device W_T*)w;

  typedef float U;
  typedef struct {
    W_T wi[tn * bytes_per_pack];
  } vec_w;

  thread vec_w w_local;
  thread U result[tn * pack_factor] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;

  // Adjust positions
  const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
  const int out_vec_size_g = out_vec_size / group_size;
  int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);
  ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
  scales += out_col / group_size + simd_lid * out_vec_size_g;
  biases += out_col / group_size + simd_lid * out_vec_size_g;
  x += tid.x * in_vec_size + simd_lid;
  y += tid.x * out_vec_size + out_col;

  if (out_col >= out_vec_size) {
    return;
  }

  // Loop over in_vec in blocks of block_size
  int remaining = in_vec_size % block_size;
  if (remaining == 0) {
    for (int i = 0; i < in_vec_size; i += block_size) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);
      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += block_size;
      scales += block_size * out_vec_size_g;
      biases += block_size * out_vec_size_g;
      ws += block_size * out_vec_size_w;
    }
  } else {
    for (int i = block_size; i < in_vec_size; i += block_size) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += block_size;
      scales += block_size * out_vec_size_g;
      biases += block_size * out_vec_size_g;
      ws += block_size * out_vec_size_w;
    }
    // Lanes past the last valid input row contribute zero (x = 0).
    if (static_cast<int>(simd_lid) < remaining) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);
    } else {
      x_local = 0;
      scale = 0;
      bias = 0;
    }
    qouter<U, tn * pack_factor, bits>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
  }

  // Accumulate in the simdgroup
#pragma clang loop unroll(full)
  for (int k = 0; k < tn * pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
#pragma clang loop unroll(full)
    for (int k = 0; k < tn * pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
}

// Tiled matrix-matrix product y = x * W^T with quantized W stored row-major
// as [N, K]. Tiles of x and of the dequantized W are staged in threadgroup
// memory (Xs, Ws) and multiplied with BlockMMA. Safe loads and stores are
// used only for partial edge tiles.
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_t_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 8>();

  constexpr int BK_padded = (BK + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  auto wl = (const device uint8_t*)w;

  x += y_row * static_cast<int64_t>(K);
  wl += y_col * K_w;
  scales += y_col * K_g;
  biases += y_col * K_g;
  y += y_row * static_cast<int64_t>(N) + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}

// Tiled matrix-matrix product y = x * W with quantized W stored as [K, N]
// (quantized along N). A K that is not a multiple of BK gets one extra
// bounds-checked iteration after the full tiles.
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_n_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 8>();

  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
  using loader_x_t = mlx::steel::
      BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BK,
      BN,
      BN_padded,
      0,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  auto wl = (const device uint8_t*)w;

  // Set the block
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * static_cast<int64_t>(K);
  wl += y_col * bytes_per_pack / pack_factor;
  scales += y_col / group_size;
  biases += y_col / group_size;
  y += y_row * static_cast<int64_t>(N) + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      // Partial final K tile.
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, num_els));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      // Partial final K tile.
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, BM));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM) {
    mma_op.store_result_safe(y, N, short2(BN, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}

// Offsets x, w, scales, biases and y to batch element tid.z. Both x and w are
// indexed by tid.z through their own batch shapes/strides.
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    device T*& y,
    int output_stride,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    const constant int64_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx = tid.z;
  uint32_t w_idx = tid.z;
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  // Widen before multiplying: uint * int would wrap at 2^32 elements.
  y += static_cast<int64_t>(tid.z) * output_stride;
}

// Gather variant: the batch indices for x and w are read from lhs_indices and
// rhs_indices at position tid.z (located via batch_shape and the lhs/rhs
// strides), then applied as above.
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T*& y,
    int output_stride,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant int64_t* lhs_strides,
    const constant int64_t* rhs_strides,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    const constant int64_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx;
  uint32_t w_idx;
  if (batch_ndims == 1) {
    x_idx = lhs_indices[tid.z * lhs_strides[0]];
    w_idx = rhs_indices[tid.z * rhs_strides[0]];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
    x_idx = lhs_indices[idx.x];
    w_idx = rhs_indices[idx.y];
  }
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  // Widen before multiplying: uint * int would wrap at 2^32 elements.
  y += static_cast<int64_t>(tid.z) * output_stride;
}

// Kernel entry for qmv_quad_impl; when `batched`, first offsets all pointers
// to batch element tid.z (output stride out_vec_size * M).
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void affine_qmv_quad(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size * M,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmv_quad_impl<T, group_size, bits, D>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      quad_gid,
      quad_lid);
}

// Kernel entry for qmv_fast_impl (optionally batched over tid.z).
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void affine_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size * M,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmv_fast_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

// Kernel entry for the bounds-checked qmv_impl (optionally batched).
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void affine_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size * M,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmv_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

// Kernel entry for qvm_impl (optionally batched).
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void affine_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size * M,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qvm_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

// Split-K vector-matrix kernel: tid.z enumerates (batch, split) pairs. The
// last split of each group (tid.z % split_k == split_k - 1) processes only
// final_block_size inputs.
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void affine_qvm_split_k(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    const constant int& final_block_size [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      y,
      out_vec_size * M,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);

  // When (in_vec_size % split_k != 0) the final block needs to be smaller
  int in_vec_size_adj =
      tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;

  qvm_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size_adj,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

// Kernel entry for qmm_t_impl: allocates the threadgroup tiles and optionally
// offsets to batch element tid.z (output stride M * N).
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void affine_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

// Kernel entry for qmm_n_impl: allocates the threadgroup tiles and optionally
// offsets to batch element tid.z (output stride M * N).
template <
    typename T,
    const int group_size,
    const int bits,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void affine_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }

  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template [[kernel]] void affine_gather_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& in_vec_size [[buffer(7)]], const constant int& out_vec_size [[buffer(8)]], const constant int& x_batch_ndims [[buffer(9)]], const constant int* x_shape [[buffer(10)]], const constant int64_t* x_strides [[buffer(11)]], const constant int& w_batch_ndims [[buffer(12)]], const constant int* w_shape [[buffer(13)]], const constant int64_t* w_strides [[buffer(14)]], const constant int64_t* s_strides [[buffer(15)]], const constant int64_t* b_strides [[buffer(16)]], const constant int& batch_ndims [[buffer(17)]], const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); qmv_fast_impl( w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void affine_gather_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& in_vec_size [[buffer(7)]], const constant int& out_vec_size [[buffer(8)]], const constant int& x_batch_ndims [[buffer(9)]], const constant int* x_shape [[buffer(10)]], const constant int64_t* x_strides [[buffer(11)]], const constant int& w_batch_ndims [[buffer(12)]], const constant int* w_shape [[buffer(13)]], const constant int64_t* w_strides [[buffer(14)]], const constant int64_t* s_strides [[buffer(15)]], const constant int64_t* b_strides [[buffer(16)]], const constant int& batch_ndims [[buffer(17)]], const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, 
w_strides, s_strides, b_strides, tid); qmv_impl( w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void affine_gather_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& in_vec_size [[buffer(7)]], const constant int& out_vec_size [[buffer(8)]], const constant int& x_batch_ndims [[buffer(9)]], const constant int* x_shape [[buffer(10)]], const constant int64_t* x_strides [[buffer(11)]], const constant int& w_batch_ndims [[buffer(12)]], const constant int* w_shape [[buffer(13)]], const constant int64_t* w_strides [[buffer(14)]], const constant int64_t* s_strides [[buffer(15)]], const constant int64_t* b_strides [[buffer(16)]], const constant int& batch_ndims [[buffer(17)]], const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); qvm_impl( w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void affine_gather_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices 
[[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& K [[buffer(7)]], const constant int& N [[buffer(8)]], const constant int& M [[buffer(9)]], const constant int& x_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(13)]], const constant int* w_shape [[buffer(14)]], const constant int64_t* w_strides [[buffer(15)]], const constant int64_t* s_strides [[buffer(16)]], const constant int64_t* b_strides [[buffer(17)]], const constant int& batch_ndims [[buffer(18)]], const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); qmm_t_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void affine_gather_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& K [[buffer(7)]], const constant int& N [[buffer(8)]], const constant int& M [[buffer(9)]], const constant int& x_batch_ndims 
[[buffer(10)]], const constant int* x_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(13)]], const constant int* w_shape [[buffer(14)]], const constant int64_t* w_strides [[buffer(15)]], const constant int64_t* s_strides [[buffer(16)]], const constant int64_t* b_strides [[buffer(17)]], const constant int& batch_ndims [[buffer(18)]], const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); qmm_n_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, int group_size, int bits, int BM, int BN, int BK, int WM, int WN, bool transpose> [[kernel]] void affine_gather_qmm_rhs( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(2)]], const device T* biases [[buffer(3)]], const device uint32_t* indices [[buffer(4)]], device T* y [[buffer(5)]], const constant int& M [[buffer(6)]], const constant int& N [[buffer(7)]], const constant int& K [[buffer(8)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { constexpr int pack_factor = get_pack_factor(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / 
sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); using mma_t = mlx::steel::BlockMMA< T, T, BM, BN, BK, WM, WN, false, transpose, BK_padded, transpose ? BK_padded : BN_padded>; using loader_x_t = mlx::steel::BlockLoader; using loader_w_t = QuantizedBlockLoader< T, transpose ? BN : BK, transpose ? BK : BN, transpose ? BK_padded : BN_padded, transpose, WM * WN * SIMD_SIZE, group_size, bits>; threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; // Compute the block const int K_w = K * bytes_per_pack / pack_factor; const int K_g = K / group_size; const int N_w = N * bytes_per_pack / pack_factor; const int N_g = N / group_size; const int K_it = K / BK; const size_t stride_w = transpose ? N * K_w : K * N_w; const size_t stride_s = transpose ? N * K_g : K * N_g; const int y_row = tid.y * BM; const int y_col = tid.x * BN; const size_t y_row_long = size_t(y_row); const size_t y_col_long = size_t(y_col); // Prepare threadgroup bounds const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); // Calculate the final tiles in the case that K is not aligned const int k_remain = K - K_it * BK; const short2 tile_x = short2(k_remain, tgp_bm); const short2 tile_w = transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); // Move x and output to the correct block auto wl = (const device uint8_t*)w; x += y_row_long * K; y += y_row_long * N + y_col_long; wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; scales += transpose ? y_col_long * K_g : y_col / group_size; biases += transpose ? 
y_col_long * K_g : y_col / group_size; // Do as many matmuls as necessary uint32_t index; short offset; uint32_t index_next = indices[y_row]; short offset_next = 0; int n = 0; while (n < tgp_bm) { n++; offset = offset_next; index = index_next; offset_next = tgp_bm; for (; n < tgp_bm; n++) { if (indices[y_row + n] != index) { offset_next = n; index_next = indices[y_row + n]; break; } } threadgroup_barrier(mem_flags::mem_none); // Prepare threadgroup mma operation thread mma_t mma_op(simd_group_id, simd_lane_id); // Prepare threadgroup loading operations thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); thread loader_w_t loader_w( wl + index * stride_w, scales + index * stride_s, biases + index * stride_s, transpose ? K : N, Ws, simd_group_id, simd_lane_id); // Matrices are all aligned check nothing if (align_M && align_N) { gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); if (!align_K) { threadgroup_barrier(mem_flags::mem_threadgroup); gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); } // Store results to device memory if (offset_next - offset == BM) { mma_op.store_result(y, N); } else { mma_op.store_result_slice( y, N, short2(0, offset), short2(BN, offset_next)); } } else { // Tile aligned so check outside of the hot loop if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); if (!align_K) { threadgroup_barrier(mem_flags::mem_threadgroup); gemm_loop_finalize( Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); } // Store results to device memory if (offset_next - offset == BM) { mma_op.store_result(y, N); } else { mma_op.store_result_slice( y, N, short2(0, offset), short2(BN, offset_next)); } } // Tile partially aligned check rows else if (align_N || tgp_bn == BN) { gemm_loop_unaligned( Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); if (!align_K) { threadgroup_barrier(mem_flags::mem_threadgroup); gemm_loop_finalize( Xs, Ws, 
mma_op, loader_x, loader_w, tile_x, tile_w); } mma_op.store_result_slice( y, N, short2(0, offset), short2(BN, offset_next)); } // Tile partially aligned check cols else if (align_M || tgp_bm == BM) { gemm_loop_unaligned( Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); if (!align_K) { threadgroup_barrier(mem_flags::mem_threadgroup); gemm_loop_finalize( Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); } mma_op.store_result_slice( y, N, short2(0, offset), short2(tgp_bn, offset_next)); } // Nothing aligned so check both rows and cols else { gemm_loop_unaligned( Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); if (!align_K) { threadgroup_barrier(mem_flags::mem_threadgroup); gemm_loop_finalize( Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); } mma_op.store_result_slice( y, N, short2(0, offset), short2(tgp_bn, offset_next)); } } } } template [[kernel]] void affine_quantize( const device T* w [[buffer(0)]], device uint8_t* out [[buffer(1)]], device T* scales [[buffer(2)]], device T* biases [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { constexpr float eps = 1e-7; constexpr int simd_size = 32; constexpr float n_bins = (1 << bits) - 1; constexpr int pack_factor = get_pack_factor(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_reduce = group_size / simd_size; constexpr int writes_per_reduce = pack_factor / values_per_reduce; constexpr int writes_per_pack = writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; static_assert( group_size % simd_size == 0, "Group size must be divisible by simd size."); size_t offset = index.x + grid_dim.x * size_t(index.y); size_t in_index = offset * values_per_reduce; size_t out_index = power_of_2_bits ? 
offset * writes_per_pack : offset * bytes_per_pack / writes_per_reduce; float w_thread[values_per_reduce]; float w_min = Limits::max; float w_max = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { float val = w[in_index + i]; w_thread[i] = val; w_min = min(w_min, val); w_max = max(w_max, val); } w_min = simd_min(w_min); w_max = simd_max(w_max); float scale = max((w_max - w_min) / n_bins, eps); bool side = abs(w_min) > abs(w_max); scale = side ? scale : -scale; float edge = side ? w_min : w_max; float q0 = round(edge / scale); bool at_zero = q0 == 0.0f; scale = at_zero ? scale : edge / q0; float bias = at_zero ? 0 : edge; // Write out the scales and biases size_t gindex = in_index / group_size; if (in_index % group_size == 0) { scales[gindex] = static_cast(scale); biases[gindex] = static_cast(bias); } using OutType = metal::conditional_t; OutType output = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); if (bits == 8) { output = val; } else { output |= val << (bits * (i % pack_factor)); } if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { out[out_index + i / pack_factor] = output; output = 0; } else { #pragma clang loop unroll(full) for (int j = 1; j < writes_per_reduce; j++) { uint8_t sval = simd_shuffle_down(val, j); output |= static_cast(sval) << (bits * (j * values_per_reduce + i)); } } } if (bits == 3 || bits == 6) { if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { out[out_index] = output & 0xff; out[out_index + 1] = (output & 0xff00) >> 8; out[out_index + 2] = (output & 0xff0000) >> 16; } } else if (bits == 5) { if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { out[out_index] = output & 0xff; out[out_index + 1] = (output & 0xff00) >> 8; out[out_index + 2] = (output & 0xff0000) >> 16; out[out_index + 3] = (output & 0xff000000) >> 24; out[out_index + 4] = (output & 
0xff00000000) >> 32; } } else { if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { out[out_index / writes_per_reduce] = output; } } } template [[kernel]] void affine_dequantize( const device uint8_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { constexpr int pack_factor = get_pack_factor(); constexpr int bytes_per_pack = get_bytes_per_pack(); size_t offset = index.x + grid_dim.x * size_t(index.y); size_t oindex = offset * pack_factor; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; out += oindex; if (bits == 3) { w += offset * bytes_per_pack; out[0] = (w[0] & 0x7) * scale + bias; out[1] = ((w[0] & 0x38) >> 3) * scale + bias; out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; out[3] = ((w[1] & 0xe) >> 1) * scale + bias; out[4] = ((w[1] & 0x70) >> 4) * scale + bias; out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; } else if (bits == 5) { w += offset * bytes_per_pack; out[0] = (w[0] & 0x1f) * scale + bias; out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; } else if (bits == 6) { w += offset * bytes_per_pack; out[0] = (w[0] & 0x3f) * scale + bias; out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; out[3] = ((w[2] >> 2) & 0x3f) * scale + bias; } else { uint val = w[offset]; #pragma 
clang loop unroll(full) for (int i = 0; i < pack_factor; i++) { uint8_t d; if (bits == 2) { d = (val >> (bits * i)) & 0x03; } else if (bits == 4) { d = (val >> (bits * i)) & 0x0f; } else if (bits == 8) { d = val; } out[i] = scale * d + bias; } } } ================================================ FILE: mlx/backend/metal/kernels/quantized.metal ================================================ // Copyright © 2023-2024 Apple Inc. // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/quantized_utils.h" #include "mlx/backend/metal/kernels/quantized.h" #define instantiate_quantized(name, type, group_size, bits) \ instantiate_kernel( \ #name "_" #type "_gs_" #group_size "_b_" #bits, \ name, \ type, \ group_size, \ bits) #define instantiate_quantized_batched(name, type, group_size, bits, batched) \ instantiate_kernel( \ #name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \ name, \ type, \ group_size, \ bits, \ batched) #define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \ instantiate_kernel( \ #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \ name, \ type, \ group_size, \ bits, \ aligned) #define instantiate_quantized_aligned_batched(name, type, group_size, bits, aligned, batched) \ instantiate_kernel( \ #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned "_batch_" #batched, \ name, \ type, \ group_size, \ bits, \ aligned, \ batched) #define instantiate_quantized_quad(name, type, group_size, bits, D, batched) \ instantiate_kernel( \ #name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D "_batch_" #batched, \ name, \ type, \ group_size, \ bits, \ D, \ batched) #define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \ instantiate_kernel( \ #name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \ name, \ type, \ group_size, \ bits, \ split_k) #define 
instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \ instantiate_kernel( \ #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ func, \ type, \ group_size, \ bits, \ bm, \ bn, \ bk, \ wm, \ wn, \ transpose) #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \ instantiate_quantized_batched(name, type, group_size, bits, 1) \ instantiate_quantized_batched(name, type, group_size, bits, 0) #define instantiate_quantized_all_batched(type, group_size, bits) \ instantiate_quantized_batched_wrap(affine_qmv_fast, type, group_size, bits) \ instantiate_quantized_batched_wrap(affine_qmv, type, group_size, bits) \ instantiate_quantized_batched_wrap(affine_qvm, type, group_size, bits) \ instantiate_quantized_batched_wrap(affine_qmm_n, type, group_size, bits) #define instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \ instantiate_quantized(affine_gather_qmv_fast, type, group_size, bits) \ instantiate_quantized(affine_gather_qmv, type, group_size, bits) \ instantiate_quantized(affine_gather_qvm, type, group_size, bits) \ instantiate_quantized(affine_gather_qmm_n, type, group_size, bits) #define instantiate_quantized_all_aligned(type, group_size, bits) \ instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, true) \ instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, false) \ instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 1) \ instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 0) \ instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 1) \ instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 0) #define instantiate_quantized_all_quad(type, group_size, bits) \ 
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 1) \ instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 0) \ instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 1) \ instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 0) #define instantiate_quantized_all_splitk(type, group_size, bits) \ instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8) \ instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32) #define instantiate_quantized_all_rhs(type, group_size, bits) \ instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \ instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false) #define instantiate_quantized_funcs(type, group_size, bits) \ instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized_all_batched(type, group_size, bits) \ instantiate_quantized_all_aligned(type, group_size, bits) \ instantiate_quantized_all_quad(type, group_size, bits) \ instantiate_quantized_all_splitk(type, group_size, bits) \ instantiate_quantized_all_rhs(type, group_size, bits) #define instantiate_quantized_types(group_size, bits) \ instantiate_quantized_funcs(float, group_size, bits) \ instantiate_quantized_funcs(float16_t, group_size, bits) \ instantiate_quantized_funcs(bfloat16_t, group_size, bits) #define instantiate_quantized_groups(bits) \ instantiate_quantized_types(128, bits) \ instantiate_quantized_types(64, bits) \ instantiate_quantized_types(32, bits) #define instantiate_quantized_all() \ instantiate_quantized_groups(2) \ instantiate_quantized_groups(3) \ instantiate_quantized_groups(4) \ instantiate_quantized_groups(5) \ instantiate_quantized_groups(6) \ instantiate_quantized_groups(8) instantiate_quantized_all() // clang-format on ================================================ FILE: 
mlx/backend/metal/kernels/quantized_nax.h ================================================ // Copyright © 2023-2024 Apple Inc. #include #include using namespace metal; using namespace mlx::steel; constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; constant bool align_K [[function_constant(202)]]; using namespace metal; #define MLX_MTL_CONST static constant constexpr const MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int QUAD_SIZE = 4; /* get_pack_factor / get_bytes_per_pack: 3- and 5-bit values are packed 8 per pack (3 and 5 bytes), 6-bit values 4 per pack (3 bytes); power-of-two widths pack wsize / bits values into wsize / 8 bytes. */ template inline constexpr short get_pack_factor() { return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); } template inline constexpr short get_bytes_per_pack() { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); } /* load_vector: returns the sum of the values_per_thread inputs and stores them in x_thread divided by the power of two matching each value's bit offset inside its packed byte(s), so qdot can multiply by masked but unshifted weight bits. */ template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || bits == 8, "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 4.0f; x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 64.0f; } } else if (bits == 3) { for (int i = 0; i < values_per_thread; i += 8) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + x[i + 6] + x[i + 7]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 8.0f; x_thread[i + 2] = x[i + 2] / 64.0f; x_thread[i + 3] = x[i + 3] / 2.0f; x_thread[i + 4] = x[i + 4] / 16.0f; x_thread[i + 5] = x[i + 5] / 128.0f; x_thread[i + 6] = x[i + 6] / 4.0f; x_thread[i + 7] = x[i + 7] / 32.0f; } } else if (bits == 4) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 16.0f; x_thread[i + 2] = x[i + 2] / 256.0f; x_thread[i + 3] = x[i + 3] / 4096.0f; } } else if (bits == 5) { for (int i = 0; i <
values_per_thread; i += 8) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + x[i + 6] + x[i + 7]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 32.0f; x_thread[i + 2] = x[i + 2] / 4.0f; x_thread[i + 3] = x[i + 3] / 128.0f; x_thread[i + 4] = x[i + 4] / 16.0f; x_thread[i + 5] = x[i + 5] / 2.0f; x_thread[i + 6] = x[i + 6] / 64.0f; x_thread[i + 7] = x[i + 7] / 8.0f; } } else if (bits == 6) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 64.0f; x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 4.0f; } } else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { sum += x[i]; x_thread[i] = x[i]; } } return sum; } /* load_vector_safe: like load_vector but reads only the first N inputs (stepping by the pack width, so N is presumably a multiple of it; TODO confirm against callers) and zero-fills x_thread[N .. values_per_thread). */ template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || bits == 8, "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 4.0f; x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 64.0f; } } else if (bits == 3) { for (int i = 0; i < N; i += 8) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + x[i + 6] + x[i + 7]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 8.0f; x_thread[i + 2] = x[i + 2] / 64.0f; x_thread[i + 3] = x[i + 3] / 2.0f; x_thread[i + 4] = x[i + 4] / 16.0f; x_thread[i + 5] = x[i + 5] / 128.0f; x_thread[i + 6] = x[i + 6] / 4.0f; x_thread[i + 7] = x[i + 7] / 32.0f; } } else if (bits == 4) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 16.0f; x_thread[i + 2] = x[i + 2] / 256.0f; x_thread[i + 3] = x[i + 3] / 4096.0f; } } else if (bits == 5) { for (int i = 0; i < N; i += 8) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] +
x[i + 5] + x[i + 6] + x[i + 7]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 32.0f; x_thread[i + 2] = x[i + 2] / 4.0f; x_thread[i + 3] = x[i + 3] / 128.0f; x_thread[i + 4] = x[i + 4] / 16.0f; x_thread[i + 5] = x[i + 5] / 2.0f; x_thread[i + 6] = x[i + 6] / 64.0f; x_thread[i + 7] = x[i + 7] / 8.0f; } } else if (bits == 6) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; x_thread[i + 1] = x[i + 1] / 64.0f; x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 4.0f; } } else if (bits == 8) { for (int i = 0; i < N; i++) { sum += x[i]; x_thread[i] = x[i]; } } for (int i = N; i < values_per_thread; i++) { x_thread[i] = 0; } return sum; } /* qdot: returns scale * accum + sum * bias, where accum pairs masked (unshifted) weight bits with the pre-divided x_thread values from load_vector so that accum is the dot product of the integer codes with x. In the 3-, 5- and 6-bit paths x_thread and w advance by exactly one pack per iteration in the for-increment, so pack i is read at x_thread + i * values_per_pack and w + i * bytes_per_pack for any number of packs (advancing by a multiple of i inside the body would accumulate offsets 0, 8, 24, ...). */ template inline U qdot( const device uint8_t* w, const thread U* x_thread, U scale, U bias, U sum) { static_assert( bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || bits == 8, "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; if (bits == 2) { for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + x_thread[4 * i + 1] * (w[i] & 0x0c) + x_thread[4 * i + 2] * (w[i] & 0x30) + x_thread[4 * i + 3] * (w[i] & 0xc0)); } } else if (bits == 3) { for (int i = 0; i < (values_per_thread / 8); i++, x_thread += 8, w += 3) { accum += (w[0] & 0x07) * x_thread[0]; accum += (w[0] & 0x38) * x_thread[1]; accum += (w[0] & 0xc0) * x_thread[2]; accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); accum += (w[1] & 0x0e) * x_thread[3]; accum += (w[1] & 0x70) * x_thread[4]; accum += (w[1] & 0x80) * x_thread[5]; accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); accum += (w[2] & 0x1c) * x_thread[6]; accum += (w[2] & 0xe0) * x_thread[7]; } } else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * (ws[i] & 0x000f) + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + x_thread[4 * i + 3] * (ws[i] &
0xf000)); } } else if (bits == 5) { for (int i = 0; i < (values_per_thread / 8); i++, x_thread += 8, w += 5) { accum += (w[0] & 0x1f) * x_thread[0]; accum += (w[0] & 0xe0) * x_thread[1]; accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); accum += (w[1] & 0x7c) * x_thread[2]; accum += (w[1] & 0x80) * x_thread[3]; accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); accum += (w[2] & 0xf0) * x_thread[4]; accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); accum += (w[3] & 0x3e) * x_thread[5]; accum += (w[3] & 0xc0) * x_thread[6]; accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); accum += (w[4] & 0xf8) * x_thread[7]; } } else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++, x_thread += 4, w += 3) { accum += (w[0] & 0x3f) * x_thread[0]; accum += (w[0] & 0xc0) * x_thread[1]; accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); accum += (w[1] & 0xf0) * x_thread[2]; accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); accum += (w[2] & 0xfc) * x_thread[3]; } } else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { accum += x_thread[i] * w[i]; } } return scale * accum + sum * bias; } /* qdot_safe: qdot restricted to the first N values (N divided by the pack width, so any remainder is ignored); the 3-, 5- and 6-bit paths likewise advance x_thread and w by one pack per iteration in the for-increment. */ template inline U qdot_safe( const device uint8_t* w, const thread U* x_thread, U scale, U bias, U sum, int N) { static_assert( bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || bits == 8, "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; if (bits == 2) { for (int i = 0; i < (N / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + x_thread[4 * i + 1] * (w[i] & 0x0c) + x_thread[4 * i + 2] * (w[i] & 0x30) + x_thread[4 * i + 3] * (w[i] & 0xc0)); } } else if (bits == 3) { for (int i = 0; i < (N / 8); i++, x_thread += 8, w += 3) { accum += (w[0] & 0x07) * x_thread[0]; accum += (w[0] & 0x38) * x_thread[1]; accum += (w[0] & 0xc0) * x_thread[2]; accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); accum += (w[1] & 0x0e) * x_thread[3]; accum += (w[1] & 0x70) * x_thread[4]; accum += (w[1] & 0x80) * x_thread[5]; accum += (w[2] & 0x03) *
(x_thread[5] * 256.0f); accum += (w[2] & 0x1c) * x_thread[6]; accum += (w[2] & 0xe0) * x_thread[7]; } } else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (N / 4); i++) { accum += (x_thread[4 * i] * (ws[i] & 0x000f) + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + x_thread[4 * i + 3] * (ws[i] & 0xf000)); } } else if (bits == 5) { for (int i = 0; i < (N / 8); i++, x_thread += 8, w += 5) { accum += (w[0] & 0x1f) * x_thread[0]; accum += (w[0] & 0xe0) * x_thread[1]; accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); accum += (w[1] & 0x7c) * x_thread[2]; accum += (w[1] & 0x80) * x_thread[3]; accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); accum += (w[2] & 0xf0) * x_thread[4]; accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); accum += (w[3] & 0x3e) * x_thread[5]; accum += (w[3] & 0xc0) * x_thread[6]; accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); accum += (w[4] & 0xf8) * x_thread[7]; } } else if (bits == 6) { for (int i = 0; i < (N / 4); i++, x_thread += 4, w += 3) { accum += (w[0] & 0x3f) * x_thread[0]; accum += (w[0] & 0xc0) * x_thread[1]; accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); accum += (w[1] & 0xf0) * x_thread[2]; accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); accum += (w[2] & 0xfc) * x_thread[3]; } } else if (bits == 8) { for (int i = 0; i < N; i++) { accum += x_thread[i] * w[i]; } } return scale * accum + sum * bias; } template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || bits == 8, "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; for (int i = 0; i < (values_per_thread / 4); i++) { result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); result[4 * i +
3] += x * (s[3] * (w[i] & 0xc0) + bias); } } else if (bits == 3) { for (int i = 0; i < (values_per_thread / 8); i++) { uint8_t w0 = w[3 * i]; uint8_t w1 = w[3 * i + 1]; uint8_t w2 = w[3 * i + 2]; result[8 * i] += x * ((w0 & 0x7) * scale + bias); result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); result[8 * i + 2] += x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); result[8 * i + 5] += x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); } } else if (bits == 4) { U s[2] = {scale, scale / 16.0f}; for (int i = 0; i < (values_per_thread / 2); i++) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); } } else if (bits == 5) { for (int i = 0; i < (values_per_thread / 8); i++) { uint8_t w0 = w[5 * i]; uint8_t w1 = w[5 * i + 1]; uint8_t w2 = w[5 * i + 2]; uint8_t w3 = w[5 * i + 3]; uint8_t w4 = w[5 * i + 4]; result[8 * i] += x * ((w0 & 0x1f) * scale + bias); result[8 * i + 1] += x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); result[8 * i + 3] += x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); result[8 * i + 4] += x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); result[8 * i + 6] += x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); } } else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { uint8_t w0 = w[3 * i]; uint8_t w1 = w[3 * i + 1]; uint8_t w2 = w[3 * i + 2]; result[4 * i] += x * ((w0 & 0x3f) * scale + bias); result[4 * i + 1] += x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * 
scale + bias); result[4 * i + 2] += x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); } } else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { result[i] += x * (scale * w[i] + bias); } } }

// dequantize: expands N packed `bits`-bit weights starting at w into
// w_local, writing scale * q + bias for every quantized value q.
//
// For 2 and 4 bits, the masked (unshifted) field is multiplied by a scale
// that has been pre-divided by that field's power of two.
// For 3, 5 and 6 bits, each field is shifted down explicitly. A field that
// straddles a byte boundary is assembled from two adjacent bytes.
//
// NOTE(review): `U`, `N` and `bits` are template parameters. Their parameter
// list, and the target type of the static_casts, is missing from this copy
// of the source.
//
// For 3, 5 and 6 bits the pointers advance by exactly one pack at the end
// of each iteration. The previous form (`w_local += 8 * i; w += 3 * i;` at
// the top of the loop) produced offsets 0, 8, 24, ... and was only correct
// for at most two iterations.
template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  if (bits == 2) {
    U s[4] = {
        scale,
        scale / static_cast(4.0f),
        scale / static_cast(16.0f),
        scale / static_cast(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  } else if (bits == 3) {
    // 8 values per 3 bytes.
    for (int i = 0; i < (N / 8); i++) {
      w_local[0] = (w[0] & 0x7) * scale + bias;
      w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
      w_local += 8;
      w += 3;
    }
  } else if (bits == 4) {
    U s[2] = {scale, scale / static_cast(16.0f)};
    for (int i = 0; i < (N / 2); i++) {
      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
    }
  } else if (bits == 5) {
    // 8 values per 5 bytes.
    for (int i = 0; i < (N / 8); i++) {
      w_local[0] = (w[0] & 0x1f) * scale + bias;
      w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
      w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
      w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
      w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
      w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
      w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
      w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
      w_local += 8;
      w += 5;
    }
  } else if (bits == 6) {
    // 4 values per 3 bytes.
    for (int i = 0; i < (N / 4); i++) {
      w_local[0] = (w[0] & 0x3f) * scale + bias;
      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
      w_local += 4;
      w += 3;
    }
  } else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      w_local[i] = scale * w[i] + bias;
    }
  }
}

template < typename T, short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits> struct QuantizedBlockLoader { static_assert( BCOLS <= group_size, "The group size should be larger than the columns"); static_assert( group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || bits == 8, "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); MLX_MTL_CONST short pack_factor = get_pack_factor(); MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; MLX_MTL_CONST short group_steps = group_size / BCOLS; const int src_ld; const int tile_stride; short group_step_cnt; const int group_stride; const short thread_idx; const short bi; const short bj; threadgroup T* dst; const device uint8_t* src; const device T* scales; const device T* biases; QuantizedBlockLoader( const device uint8_t* src_, const device T* scales_, const device T* biases_, const int src_ld_, threadgroup T* dst_, ushort simd_group_id [[simdgroup_index_in_threadgroup]], ushort simd_lane_id [[thread_index_in_simdgroup]]) : src_ld(src_ld_), tile_stride( reduction_dim ?
BCOLS_PACKED * bytes_per_pack : BROWS * src_ld * bytes_per_pack / pack_factor), group_step_cnt(0), group_stride(BROWS * src_ld / group_size), thread_idx(simd_group_id * 32 + simd_lane_id), bi(n_reads * thread_idx / BCOLS_PACKED), bj((n_reads * thread_idx) % BCOLS_PACKED), dst(dst_ + bi * dst_ld + bj * pack_factor), src(src_ + bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack), scales(scales_ + bi * src_ld / group_size), biases(biases_ + bi * src_ld / group_size) {} void load_unsafe() const { if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { return; } T scale = *scales; T bias = *biases; for (int i = 0; i < n_reads; i++) { dequantize( src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); } } void load_safe(short2 src_tile_dim) const { if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { return; } if (reduction_dim == 1 && bi >= src_tile_dim.x) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } return; } if (reduction_dim == 0 && bi >= src_tile_dim.y) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } return; } T scale = *scales; T bias = *biases; for (int i = 0; i < n_reads; i++) { dequantize( (device uint8_t*)(src + i * bytes_per_pack), scale, bias, dst + i * pack_factor); } } void next() { src += tile_stride; if (reduction_dim == 1) { if (group_steps > 1) { group_step_cnt++; if (group_step_cnt == group_steps) { group_step_cnt = 0; scales++; biases++; } } else { scales++; biases++; } } else { scales += group_stride; biases += group_stride; } } }; template < typename T, short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short bits> struct QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, 32, bits> { MLX_MTL_CONST short group_size = 32; static_assert( BCOLS % group_size == 0, "The group size should be divisible by the columns"); static_assert( bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || bits == 8, "Template undefined for bits not in 
{2, 3, 4, 5, 6, 8}"); MLX_MTL_CONST short pack_factor = get_pack_factor(); MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; MLX_MTL_CONST short n_groups = BCOLS / group_size; static_assert( (BCOLS_PACKED / n_reads) == n_groups, "Other configurations are not yet supported"); const int src_ld; const int tile_stride; const int group_stride; const short thread_idx; const short bi; const short bj; const short group_id; threadgroup T* dst; const device uint8_t* src; const device T* scales; const device T* biases; QuantizedBlockLoader( const device uint8_t* src_, const device T* scales_, const device T* biases_, const int src_ld_, threadgroup T* dst_, ushort simd_group_id [[simdgroup_index_in_threadgroup]], ushort simd_lane_id [[thread_index_in_simdgroup]]) : src_ld(src_ld_), tile_stride( reduction_dim ? BCOLS_PACKED * bytes_per_pack : BROWS * src_ld * bytes_per_pack / pack_factor), group_stride(BROWS * src_ld / group_size), thread_idx(simd_group_id * 32 + simd_lane_id), bi(n_reads * thread_idx / BCOLS_PACKED), bj((n_reads * thread_idx) % BCOLS_PACKED), group_id((bj * pack_factor) / group_size), dst(dst_ + bi * dst_ld + bj * pack_factor), src(src_ + bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack), scales(scales_ + bi * src_ld / group_size + group_id), biases(biases_ + bi * src_ld / group_size + group_id) {} void load_unsafe() const { if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { return; } T scale = *scales; T bias = *biases; for (int i = 0; i < n_reads; i++) { dequantize( src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); } } void load_safe(short2 src_tile_dim) const { if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { return; } if (reduction_dim == 1 && bi >= src_tile_dim.x) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } return; } if 
(reduction_dim == 0 && bi >= src_tile_dim.y) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } return; } T scale = *scales; T bias = *biases; for (int i = 0; i < n_reads; i++) { dequantize( (device uint8_t*)(src + i * bytes_per_pack), scale, bias, dst + i * pack_factor); } } void next() { src += tile_stride; if (reduction_dim == 1) { // if (group_steps > 1) { // group_step_cnt++; // if (group_step_cnt == group_steps) { // group_step_cnt = 0; // scales++; // biases++; // } // } else { scales += n_groups; biases += n_groups; // } } else { scales += n_groups * group_stride; biases += n_groups * group_stride; } } }; template METAL_FUNC void adjust_matrix_offsets( const device T*& x, const device uint32_t*& w, const device T*& scales, const device T*& biases, device T*& y, int output_stride, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, const constant int64_t* b_strides, uint3 tid [[threadgroup_position_in_grid]]) { // Set the input/output matrices uint32_t x_idx = tid.z; uint32_t w_idx = tid.z; if (x_batch_ndims == 1) { x += x_idx * x_strides[0]; } else { x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); } if (w_batch_ndims == 1) { w += w_idx * w_strides[0]; scales += w_idx * s_strides[0]; biases += w_idx * b_strides[0]; } else { ulong3 idx = elem_to_loc_broadcast( w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); w += idx.x; scales += idx.y; biases += idx.z; } y += tid.z * output_stride; } template METAL_FUNC void adjust_matrix_offsets( const device T*& x, const device uint32_t*& w, const device T*& scales, const device T*& biases, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, device T*& y, int output_stride, const constant int& batch_ndims, const constant int* batch_shape, const constant int64_t* lhs_strides, const 
constant int64_t* rhs_strides, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, const constant int64_t* b_strides, uint3 tid [[threadgroup_position_in_grid]]) { // Set the input/output matrices uint32_t x_idx; uint32_t w_idx; if (batch_ndims == 1) { x_idx = lhs_indices[tid.z * lhs_strides[0]]; w_idx = rhs_indices[tid.z * rhs_strides[0]]; } else { ulong2 idx = elem_to_loc_broadcast( tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); x_idx = lhs_indices[idx.x]; w_idx = rhs_indices[idx.y]; } if (x_batch_ndims == 1) { x += x_idx * x_strides[0]; } else { x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); } if (w_batch_ndims == 1) { w += w_idx * w_strides[0]; scales += w_idx * s_strides[0]; biases += w_idx * b_strides[0]; } else { ulong3 idx = elem_to_loc_broadcast( w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); w += idx.x; scales += idx.y; biases += idx.z; } y += tid.z * output_stride; } template < typename T, const int group_size, const int bits, const bool aligned_N, const int BM = 64, const int BK = 64, const int BN = 64, const int WM = 2, const int WN = 2> METAL_FUNC void qmm_t_nax_tgp_impl( const device uint32_t* w, const device T* scales, const device T* biases, const device T* x, device T* y, threadgroup T* Ws, const constant int& K, const constant int& N, const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); (void)lid; constexpr int pack_factor = get_pack_factor(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / 
sizeof(T)); using loader_w_t = QuantizedBlockLoader< T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>; // Set the block const int K_w = K * bytes_per_pack / pack_factor; const int K_g = K / group_size; const int y_row = tid.y * BM; const int y_col = tid.x * BN; auto wl = (const device uint8_t*)w; x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; y += y_row * static_cast(N) + y_col; // Make the weight loader loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; constexpr short TM = SM / 16; constexpr short TN = SN / 16; constexpr short TK = SK / 16; const short tm = SM * (simd_gid / WN); const short tn = SN * (simd_gid % WN); constexpr bool transpose_a = false; constexpr bool transpose_b = true; const short sgp_sm = min(SM, short(M - (y_row + tm))); const bool is_unaligned_sm = (sgp_sm != SM); const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); const bool is_unaligned_bn = aligned_N ? 
false : (tgp_bn != BN); using AccumType = float; NAXTile Dtile; Dtile.clear(); x += tm * K; dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); if constexpr (kAlignedN.value) { loader_w.load_unsafe(); } else { loader_w.load_safe(short2(BK, tgp_bn)); } threadgroup_barrier(mem_flags::mem_threadgroup); STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { NAXTile Atile; NAXTile Btile; volatile int compiler_barrier; if constexpr (kAlignedM.value) { Atile.load(x + kk1, K); } else { Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); } Btile.template load(Ws + tn * BK_padded + kk1); tile_matmad_nax( Dtile, Atile, metal::bool_constant{}, Btile, metal::bool_constant{}); (void)compiler_barrier; } x += BK; loader_w.next(); } // Store results to device memory threadgroup_barrier(mem_flags::mem_threadgroup); if constexpr (kAlignedM.value && kAlignedN.value) { Dtile.store(y + tm * N + tn, N); } else if (kAlignedM.value && sgp_sn == SN) { Dtile.store(y + tm * N + tn, N); } else { Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); } }); }); } template < typename T, const int group_size, const int bits, const int BM = 64, const int BK = 64, const int BN = 64, const int WM = 2, const int WN = 2> METAL_FUNC void qmm_n_nax_tgp_impl( const device uint32_t* w, const device T* scales, const device T* biases, const device T* x, device T* y, threadgroup T* Ws, const constant int& K, const constant int& N, const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; (void)M; static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); constexpr int pack_factor = get_pack_factor(); 
constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BN_padded = (BN + 16 / sizeof(T)); using loader_w_t = QuantizedBlockLoader< T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>; // Set the block const int K_w = K * bytes_per_pack / pack_factor; const int K_g = K / group_size; const int y_row = tid.y * BM; const int y_col = tid.x * BN; auto wl = (const device uint8_t*)w; x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation // const short num_els = min(BM, M - y_row); // const short num_outs = min(BN, N - y_col); loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; constexpr short TM = SM / 16; constexpr short TN = SN / 16; constexpr short TK = SK / 16; const short tm = SM * (simd_gid / WN); const short tn = SN * (simd_gid % WN); const short ldb_tgp = BN_padded; constexpr bool transpose_a = false; constexpr bool transpose_b = false; using AccumType = float; NAXTile Dtile; Dtile.clear(); x += tm * K; for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { NAXTile Atile; NAXTile Btile; volatile int compiler_barrier; Atile.load(x + kk1, K); Btile.template load(Ws + tn + kk1 * ldb_tgp); tile_matmad_nax( Dtile, Atile, metal::bool_constant{}, Btile, metal::bool_constant{}); (void)compiler_barrier; } x += BK; loader_w.next(); } // Store results to device memory threadgroup_barrier(mem_flags::mem_threadgroup); Dtile.store(y + tm * N + tn, N); } template < typename T, const int group_size, const int bits, const bool aligned_N, const bool batched, const int BM = 64, const int BK = 32, const int BN = 64, const int WM = 2, const int WN = 2> [[kernel]] void affine_qmm_t_nax( 
const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& K [[buffer(5)]], const constant int& N [[buffer(6)]], const constant int& M [[buffer(7)]], const constant int& x_batch_ndims [[buffer(8)]], const constant int* x_shape [[buffer(9)]], const constant int64_t* x_strides [[buffer(10)]], const constant int& w_batch_ndims [[buffer(11)]], const constant int* w_shape [[buffer(12)]], const constant int64_t* w_strides [[buffer(13)]], const constant int64_t* s_strides [[buffer(14)]], const constant int64_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); threadgroup T Ws[BN * BK_padded]; if (batched) { adjust_matrix_offsets( x, w, scales, biases, y, M * N, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); } qmm_t_nax_tgp_impl( w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const bool batched, const int BM = 64, const int BK = 64, const int BN = 64, const int WM = 2, const int WN = 2> [[kernel]] void affine_qmm_n_nax( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& K [[buffer(5)]], const constant int& N [[buffer(6)]], const constant int& M [[buffer(7)]], const constant int& x_batch_ndims [[buffer(8)]], const constant int* x_shape [[buffer(9)]], const constant int64_t* x_strides [[buffer(10)]], const constant int& w_batch_ndims [[buffer(11)]], const constant int* w_shape [[buffer(12)]], const constant int64_t* w_strides [[buffer(13)]], const 
constant int64_t* s_strides [[buffer(14)]], const constant int64_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BN_padded = (BN + 16 / sizeof(T)); threadgroup T Ws[BK * BN_padded]; if (batched) { adjust_matrix_offsets( x, w, scales, biases, y, M * N, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); } qmm_n_nax_tgp_impl( w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const bool aligned_N, const int BM = 64, const int BK = 64, const int BN = 64, const int WM = 2, const int WN = 2> [[kernel]] void affine_gather_qmm_t_nax( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& K [[buffer(7)]], const constant int& N [[buffer(8)]], const constant int& M [[buffer(9)]], const constant int& x_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(13)]], const constant int* w_shape [[buffer(14)]], const constant int64_t* w_strides [[buffer(15)]], const constant int64_t* s_strides [[buffer(16)]], const constant int64_t* b_strides [[buffer(17)]], const constant int& batch_ndims [[buffer(18)]], const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid 
[[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); threadgroup T Ws[BN * BK_padded]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); qmm_t_nax_tgp_impl( w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const int BM = 64, const int BK = 64, const int BN = 64, const int WM = 2, const int WN = 2> [[kernel]] void affine_gather_qmm_n_nax( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& K [[buffer(7)]], const constant int& N [[buffer(8)]], const constant int& M [[buffer(9)]], const constant int& x_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(13)]], const constant int* w_shape [[buffer(14)]], const constant int64_t* w_strides [[buffer(15)]], const constant int64_t* s_strides [[buffer(16)]], const constant int64_t* b_strides [[buffer(17)]], const constant int& batch_ndims [[buffer(18)]], const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BN_padded = (BN + 16 / sizeof(T)); threadgroup T Ws[BK * BN_padded]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims, 
batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides,
    w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  qmm_n_nax_tgp_impl(
      w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

// Quantized matmul where the quantized weight matrix (the "rhs") is gathered
// per row of x: output row y_row + n is computed with the packed weights,
// scales and biases located at offset indices[y_row + n] * stride_{w,s}.
//
// Each threadgroup produces a BM x BN tile of y whose top-left corner is
// (tid.y * BM, tid.x * BN). The while-loop below splits the tile's rows into
// runs of consecutive rows that share the same index; each run is computed
// with one full pass over K and only the rows of that run
// ([offset, offset_next)) are stored.
//
// transpose: when true W holds N rows of K packed values (stride_w = N * K_w);
// otherwise K rows of N packed values (stride_w = K * N_w).
//
// NOTE(review): align_M / align_N / align_K are not template parameters of
// this kernel; presumably they are function constants declared elsewhere in
// this header.
template <
    typename T,
    int group_size,
    int bits,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose>
[[kernel]] void affine_gather_qmm_rhs_nax(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* indices [[buffer(4)]],
    device T* y [[buffer(5)]],
    const constant int& M [[buffer(6)]],
    const constant int& N [[buffer(7)]],
    const constant int& K [[buffer(8)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  constexpr int pack_factor = get_pack_factor();
  constexpr int bytes_per_pack = get_bytes_per_pack();
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  using loader_w_t = QuantizedBlockLoader<
      T,
      transpose ? BN : BK,
      transpose ? BK : BN,
      transpose ? BK_padded : BN_padded,
      transpose,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];

  // Compute the block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int N_w = N * bytes_per_pack / pack_factor;
  const int N_g = N / group_size;
  const int K_it = K / BK;
  // Distance between consecutive gathered matrices: bytes for the packed
  // weights, elements for scales/biases. Widen before multiplying so that
  // large weight matrices cannot overflow a 32-bit int.
  const size_t stride_w = transpose ? size_t(N) * K_w : size_t(K) * N_w;
  const size_t stride_s = transpose ? size_t(N) * K_g : size_t(K) * N_g;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  const size_t y_row_long = size_t(y_row);
  const size_t y_col_long = size_t(y_col);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));

  // Calculate the final tiles in the case that K is not aligned
  const int k_remain = K - K_it * BK;
  const short2 tile_w =
      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

  // Move x and output to the correct block
  auto wl = (const device uint8_t*)w;
  x += y_row_long * K;
  y += y_row_long * N + y_col_long;
  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
  scales += transpose ? y_col_long * K_g : y_col / group_size;
  biases += transpose ? y_col_long * K_g : y_col / group_size;

  // Each simdgroup owns an SM x SN sub-tile and consumes K in SK-wide slices.
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;

  constexpr short TM = SM / 16;
  constexpr short TN = SN / 16;
  constexpr short TK = SK / 16;

  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  // Rows / columns of this simdgroup's sub-tile that are inside y.
  const short sgp_sm =
      align_M ? SM : min(SM, short(max(0, (M - (y_row + tm)))));
  const short sgp_sn =
      align_N ? SN : min(SN, short(max(0, (N - (y_col + tn)))));

  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
  const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN);

  constexpr short BR = transpose ? TN : TK;
  constexpr short BC = transpose ? TK : TN;

  using AccumType = float;

  // Do as many matmuls as necessary
  uint32_t index;
  short offset;
  uint32_t index_next = indices[y_row];
  short offset_next = 0;
  int n = 0;
  while (n < tgp_bm) {
    // Find the run [offset, offset_next) of consecutive tile rows that share
    // the index `index`; those rows are computed in this iteration.
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = tgp_bm;
    for (; n < tgp_bm; n++) {
      if (indices[y_row + n] != index) {
        offset_next = n;
        index_next = indices[y_row + n];
        break;
      }
    }
    threadgroup_barrier(mem_flags::mem_none);

    NAXTile Dtile;

    Dtile.clear();

    const device T* xn = x + tm * K;

    // Prepare threadgroup loading operations
    thread loader_w_t loader_w(
        wl + index * stride_w,
        scales + index * stride_s,
        biases + index * stride_s,
        transpose ? K : N,
        Ws,
        simd_group_id,
        simd_lane_id);

    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
      dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) {
        // Full BK-wide steps over K.
        for (int k = 0; k < K_it; k++) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          if constexpr (kAlignedN.value) {
            loader_w.load_unsafe();
          } else {
            loader_w.load_safe(
                transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
          }

          threadgroup_barrier(mem_flags::mem_threadgroup);

          STEEL_PRAGMA_NO_UNROLL
          for (int kk1 = 0; kk1 < BK; kk1 += SK) {
            NAXTile Atile;
            NAXTile Btile;

            volatile int compiler_barrier;

            if constexpr (kAlignedM.value) {
              Atile.load(xn + kk1, K);
            } else {
              Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm));
            }

            if constexpr (transpose) {
              Btile.template load(Ws + tn * BK_padded + kk1);
            } else {
              Btile.template load(Ws + tn + kk1 * BN_padded);
            }

            tile_matmad_nax(
                Dtile,
                Atile,
                metal::bool_constant{},
                Btile,
                metal::bool_constant{});

            (void)compiler_barrier;
          }

          xn += BK;
          loader_w.next();
        }

        // Tail of K: only k_remain (< BK) columns are left.
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          loader_w.load_safe(tile_w);
          threadgroup_barrier(mem_flags::mem_threadgroup);

          STEEL_PRAGMA_NO_UNROLL
          for (int kk1 = 0; kk1 < BK; kk1 += SK) {
            NAXTile Atile;
            NAXTile Btile;

            volatile int compiler_barrier;

            // Number of columns of x that are still inside K for this
            // SK-wide slice. It must be bounded by k_remain rather than BK:
            // BK - kk1 is always >= SK, which made the load read x past
            // column K (into the next row, or past the end of the buffer for
            // the last row).
            const short psk = min(int(SK), max(0, (k_remain - kk1)));
            Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm));

            if constexpr (transpose) {
              Btile.template load(Ws + tn * BK_padded + kk1);
            } else {
              Btile.template load(Ws + tn + kk1 * BN_padded);
            }

            tile_matmad_nax(
                Dtile,
                Atile,
                metal::bool_constant{},
                Btile,
                metal::bool_constant{});

            (void)compiler_barrier;
          }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Rows of this simdgroup's sub-tile that belong to the current run.
        const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm));
        const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm));

        // Store results to device memory
        if constexpr (kAlignedN.value) {
          if (m_lo_lim == 0 && m_hi_lim == SM) {
            Dtile.store(y + tm * N + tn, N);
          } else {
            Dtile.store_slice(
                y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim));
          }
        } else {
          Dtile.store_slice(
              y + tm * N + tn,
              N,
              short2(0, m_lo_lim),
              short2(sgp_sn, m_hi_lim));
        }
      });
    });
  }
}

================================================ FILE: mlx/backend/metal/kernels/quantized_nax.metal ================================================
// Copyright © 2023-2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/quantized_nax.h"

// Kernel instantiation helpers. The generated names are presumably matched by
// host-side kernel lookup, so keep their format in sync with the host code.
// The instantiate_quantized* macros forward tile sizes as (bm, bk, bn, wm, wn),
// while instantiate_gather_qmm_rhs forwards (bm, bn, bk, wm, wn). The latter
// matches affine_gather_qmm_rhs_nax's template order (BM, BN, BK, ...). The
// former presumably matches the other kernels (declared outside this view).
#define instantiate_quantized(name, type, group_size, bits, bm, bn, bk, wm, wn) \
  instantiate_kernel( \
    #name "_" #type "_gs_" #group_size "_b_" #bits, \
    name, \
    type, \
    group_size, \
    bits, bm, bk, bn, wm, wn)

#define instantiate_quantized_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, batched) \
  instantiate_kernel( \
    #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
    name, \
    type, \
    group_size, \
    bits, \
    batched, bm, bk, bn, wm, wn)

#define instantiate_quantized_aligned(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned) \
  instantiate_kernel( \
    #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
    name, \
    type, \
    group_size, \
    bits, \
    aligned, bm, bk, bn, wm, wn)

#define instantiate_quantized_aligned_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned, batched) \
  instantiate_kernel( \
    #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
    name, \
    type, \
    group_size, \
    bits, \
    aligned, \
    batched, bm, bk, bn, wm, wn)

#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
  instantiate_kernel( \
    #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
    func, \
    type, \
    group_size, \
    bits, \
    bm, \
    bn, \
    bk, \
    wm, \
    wn, \
    transpose)

#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
  instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 1) \
  instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 0)

#define instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_batched_wrap(affine_qmm_n_nax, type, group_size, bits)

// NOTE(review): this macro is not referenced by instantiate_quantized_funcs
// below, so affine_gather_qmm_n_nax is not instantiated by this file; confirm
// that this is intended.
#define instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized(affine_gather_qmm_n_nax, type, group_size, bits, 64, 64, 64, 2, 2)

#define instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true) \
  instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false) \
  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 1) \
  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 0) \
  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 1) \
  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 0)

// "_nt" instantiates the transposed-weight variant, "_nn" the non-transposed.
#define instantiate_quantized_all_rhs(type, group_size, bits) \
  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nt, type, group_size, bits, 64, 64, 64, 2, 2, true) \
  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nn, type, group_size, bits, 64, 64, 64, 2, 2, false)

#define instantiate_quantized_funcs(type, group_size, bits) \
  instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_all_rhs(type, group_size, bits)

#define instantiate_quantized_types(group_size, bits) \
  instantiate_quantized_funcs(float, group_size, bits) \
  instantiate_quantized_funcs(float16_t, group_size, bits) \
  instantiate_quantized_funcs(bfloat16_t, group_size, bits)

#define instantiate_quantized_groups(bits) \
  instantiate_quantized_types(128, bits) \
  instantiate_quantized_types(64, bits) \
  instantiate_quantized_types(32, bits)

#define instantiate_quantized_all() \
  instantiate_quantized_groups(2) \
  instantiate_quantized_groups(3) \
  instantiate_quantized_groups(4) \
  instantiate_quantized_groups(5) \
  instantiate_quantized_groups(6) \
  instantiate_quantized_groups(8)

instantiate_quantized_all() // clang-format on

================================================ FILE: mlx/backend/metal/kernels/quantized_utils.h ================================================
// Copyright © 2023-2024 Apple Inc.

// NOTE(review): in this copy the header names of the two #include directives
// below, and the template parameter lists of gemm_loop_aligned and
// gemm_loop_finalize, are empty; restore them from the original file.
#include
#include

// Runs k_iterations steps of: load one A tile and one B tile into threadgroup
// memory without bounds checks (load_unsafe), accumulate As * Bs into mma_op,
// then advance both loaders.
template METAL_FUNC void gemm_loop_aligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations) {
  for (int k = 0; k < k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load elements into threadgroup memory
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }
}

// Same as gemm_loop_aligned, but uses bounds-checked loads for the A tile
// unless rows_aligned, and for the B tile unless cols_aligned. The B bound is
// (tgp_bk, tgp_bn) when transpose is true and (tgp_bn, tgp_bk) otherwise.
template <
    bool rows_aligned,
    bool cols_aligned,
    bool transpose,
    typename T,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations,
    const short tgp_bm,
    const short tgp_bn,
    const short tgp_bk) {
  for (int k = 0; k < k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load elements into threadgroup memory
    if (rows_aligned) {
      loader_a.load_unsafe();
    } else {
      loader_a.load_safe(short2(tgp_bk, tgp_bm));
    }
    if (cols_aligned) {
      loader_b.load_unsafe();
    } else {
      loader_b.load_safe(
          transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }
}

// Processes the last, partial K tile: bounds-checked loads of tile_a / tile_b
// followed by one mma.
// NOTE(review): there is no barrier before the loads, and the loops above end
// right after an mma; callers presumably insert a threadgroup barrier before
// calling this so As/Bs are not overwritten while still being read.
template METAL_FUNC void gemm_loop_finalize(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const short2 tile_a,
    const short2 tile_b) {
  loader_a.load_safe(tile_a);
  loader_b.load_safe(tile_b);
  threadgroup_barrier(mem_flags::mem_threadgroup);
  mma_op.mma(As, Bs);
}

================================================ FILE: mlx/backend/metal/kernels/random.metal ================================================
// Copyright © 2023 Apple Inc.

#include "mlx/backend/metal/kernels/utils.h"

// Rotation amounts of the hash rounds; group i of four rounds uses
// rotations[i % 2].
static constexpr constant uint32_t rotations[2][4] = {
    {13, 15, 26, 6},
    {17, 29, 16, 24}};

// 64 random bits viewed either as two 32-bit words or as 8 bytes.
union rbits {
  uint2 val;
  uchar4 bytes[2];
};

// Threefry-2x32 style hash of `count` under `key`: 5 groups of 4
// add/rotate/xor rounds, with a key-schedule injection after each group.
// Only ks[0..2] are used; the fourth lane of ks is zero-initialized.
rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
  uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};

  rbits v;
  v.val.x = count.x + ks[0];
  v.val.y = count.y + ks[1];

  for (int i = 0; i < 5; ++i) {
    for (auto r : rotations[i % 2]) {
      v.val.x += v.val.y;
      v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
      v.val.y ^= v.val.x;
    }
    v.val.x += ks[(i + 1) % 3];
    v.val.y += ks[(i + 2) % 3] + i + 1;
  }

  return v;
}

// Writes bytes_per_key random bytes per key, reading key index.x from the
// contiguous pair keys[2 * index.x], keys[2 * index.x + 1].
// Thread (x, y) hashes the counter pair (y, y + grid_dim.y). It writes the
// first 4 bytes at byte offset 4 * y. Unless it is the dropped slot of an odd
// count, it writes the second 4 bytes at 4 * (y + grid_dim.y); that word is
// truncated to bytes_per_key % 4 bytes when it is the key's last word.
[[kernel]] void rbitsc(
    device const uint32_t* keys,
    device char* out,
    constant const bool& odd,
    constant const uint& bytes_per_key,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  auto kidx = 2 * index.x;
  auto key = uint2(keys[kidx], keys[kidx + 1]);
  auto half_size = grid_dim.y - odd;
  // Widen before multiplying (as rbits does) so that large outputs do not
  // overflow the 32-bit product index.x * bytes_per_key.
  out += size_t(index.x) * bytes_per_key;
  bool drop_last = odd && (index.y == half_size);
  auto bits = threefry2x32_hash(
      key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y));
  size_t idx = size_t(index.y) << 2;
  for (int i = 0; i < 4; ++i) {
    out[idx + i] = bits.bytes[0][i];
  }
  if (!drop_last) {
    idx = (size_t(index.y) + grid_dim.y) << 2;
    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
      int edge_bytes = (bytes_per_key % 4);
      for (int i = 0; i < edge_bytes; ++i) {
        out[idx + i] = bits.bytes[1][i];
      }
    } else {
      for (int i = 0; i < 4; ++i) {
        out[idx + i] = bits.bytes[1][i];
      }
    }
  }
}

// Same as rbitsc, but the two words of key index.x are located through
// elem_to_loc with key_shape / key_strides (non-contiguous keys).
[[kernel]] void rbits(
    device const uint32_t* keys,
    device char* out,
    constant const bool& odd,
    constant const uint& bytes_per_key,
    constant const int& ndim,
    constant const int* key_shape,
    constant const int64_t* key_strides,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  auto kidx = 2 * index.x;
  auto k1_elem = elem_to_loc(kidx, key_shape, key_strides, ndim);
  auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim);
  auto key = uint2(keys[k1_elem], keys[k2_elem]);
  auto half_size = grid_dim.y - odd;
  out += size_t(index.x) * bytes_per_key;
  bool drop_last = odd && (index.y == half_size);
  auto bits = threefry2x32_hash(
      key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y));
  size_t idx = size_t(index.y) << 2;
  for (int i = 0; i < 4; ++i) {
    out[idx + i] = bits.bytes[0][i];
  }
  if (!drop_last) {
    idx = (size_t(index.y) + grid_dim.y) << 2;
    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
      int edge_bytes = (bytes_per_key % 4);
      for (int i = 0; i < edge_bytes; ++i) {
        out[idx + i] = bits.bytes[1][i];
      }
    } else {
      for (int i = 0; i < 4; ++i) {
        out[idx + i] = bits.bytes[1][i];
      }
    }
  }
}

================================================ FILE: mlx/backend/metal/kernels/reduce.h ================================================
#pragma once
#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduction/reduce_row.h"

================================================ FILE: mlx/backend/metal/kernels/reduce.metal ================================================
// Copyright © 2024 Apple Inc.

#include
#include

// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduce.h"

// init_reduce kernels, one per (op, type) pair.
#define instantiate_init_reduce(name, tname, type, op) \
  instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op)

instantiate_init_reduce(and, bool_, bool, And)
instantiate_init_reduce(or, bool_, bool, Or)

#define instantiate_init_sum_prod(name, op) \
  instantiate_init_reduce(name, int32, int32_t, op) \
  instantiate_init_reduce(name, int64, int64_t, op) \
  instantiate_init_reduce(name, float16, float16_t, op) \
  instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
  instantiate_init_reduce(name, float32, float, op) \
  instantiate_init_reduce(name, complex64, complex64_t, op)

instantiate_init_sum_prod(sum, Sum)
instantiate_init_sum_prod(prod, Prod)

#define instantiate_init_min_max(name, op) \
  instantiate_init_reduce(name, bool_, bool, op) \
  instantiate_init_reduce(name, int8, int8_t, op) \
  instantiate_init_reduce(name, int16, int16_t, op) \
  instantiate_init_reduce(name, int32, int32_t, op) \
  instantiate_init_reduce(name, int64, int64_t, op) \
  instantiate_init_reduce(name, uint8, uint8_t, op) \
  instantiate_init_reduce(name, uint16, uint16_t, op) \
  instantiate_init_reduce(name, uint32, uint32_t, op) \
  instantiate_init_reduce(name, uint64, uint64_t, op) \
  instantiate_init_reduce(name, float16, float16_t, op) \
  instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
  instantiate_init_reduce(name, float32, float, op) \
  instantiate_init_reduce(name, complex64, complex64_t, op)

instantiate_init_min_max(min, Min)
instantiate_init_min_max(max, Max)

#define instantiate_all_reduce(name, itype, otype, op) \
  instantiate_kernel("all_reduce_" #name, \
                     all_reduce, \
                     itype, otype, op)

// Column reductions: "_large_" variants use int64_t indexing, the others int.
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
  instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
                     col_reduce_small, \
                     itype, otype, op, int, dim) \
  instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
                     col_reduce_longcolumn, \
                     itype, otype, op, int, dim) \
  instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
                     col_reduce_small, \
                     itype, otype, op, int64_t, dim) \
  instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
                     col_reduce_longcolumn, \
                     itype, otype, op, int64_t, dim)

#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
  instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
                     col_reduce_looped, \
                     itype, otype, op, int, dim, bm, bn) \
  instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
                     col_reduce_looped, \
                     itype, otype, op, int64_t, dim, bm, bn)

#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
  instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
                     col_reduce_2pass, \
                     itype, otype, op, int, dim, bm, bn) \
  instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
                     col_reduce_2pass, \
                     itype, otype, op, int64_t, dim, bm, bn)

#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
  instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)

#define instantiate_col_reduce_general(name, itype, otype, op) \
  instantiate_col_reduce_small(name, itype, otype, op, 1) \
  instantiate_col_reduce_small(name, itype, otype, op, 2) \
  instantiate_col_reduce_small(name, itype, otype, op, 5) \
  instantiate_col_reduce_looped(name, itype, otype, op, 1) \
  instantiate_col_reduce_looped(name, itype, otype, op, 2) \
  instantiate_col_reduce_looped(name, itype, otype, op, 5)

#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
  instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
                     row_reduce_small, \
                     itype, otype, op, int, dim) \
  instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
                     row_reduce_small, \
                     itype, otype, op, int64_t, dim)

#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
  instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
                     row_reduce_looped, \
                     itype, otype, op, int, dim) \
  instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
                     row_reduce_looped, \
                     itype, otype, op, int64_t, dim)

#define instantiate_row_reduce_general(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op, 1) \
  instantiate_row_reduce_small(name, itype, otype, op, 2) \
  instantiate_row_reduce_small(name, itype, otype, op, 5) \
  instantiate_row_reduce_looped(name, itype, otype, op, 1) \
  instantiate_row_reduce_looped(name, itype, otype, op, 2) \
  instantiate_row_reduce_looped(name, itype, otype, op, 5) \
  instantiate_kernel("row_reduce_simple_" #name, \
                     row_reduce_simple, \
                     itype, otype, op)

// All reduction kernels for one (op, input type, output type) combination.
#define instantiate_reduce_functions(name, tname, itype, otype, op) \
  instantiate_all_reduce(name##tname, itype, otype, op) \
  instantiate_row_reduce_general(name##tname, itype, otype, op) \
  instantiate_col_reduce_general(name##tname, itype, otype, op)

#define instantiate_and_or(name, op) \
  instantiate_reduce_functions(name, bool_, bool, bool, op) \
  instantiate_reduce_functions(name, int16, int16_t, bool, op) \
  instantiate_reduce_functions(name, int32, int32_t, bool, op) \
  instantiate_reduce_functions(name, int64, int64_t, bool, op)

instantiate_and_or(and, And)
instantiate_and_or(or, Or)

#define instantiate_sum_prod(name, op) \
  instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
  instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
  instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
  instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
  instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
  instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
  instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
  instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
  instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
  instantiate_reduce_functions(name, float32, float, float, op) \
  instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)

instantiate_sum_prod(sum, Sum)
instantiate_sum_prod(prod, Prod)

#define instantiate_min_max(name, op) \
  instantiate_reduce_functions(name, int8, int8_t, int8_t, op) \
  instantiate_reduce_functions(name, int16, int16_t, int16_t, op) \
  instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
  instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
  instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op) \
  instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op) \
  instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
  instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
  instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
  instantiate_reduce_functions(name, float32, float, float, op) \
  instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)

instantiate_min_max(min, Min)
instantiate_min_max(max, Max)
// clang-format on

================================================ FILE: mlx/backend/metal/kernels/reduce_utils.h ================================================
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"

================================================ FILE: mlx/backend/metal/kernels/reduction/ops.h ================================================
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include
#include

// Adds simd_reduce(val) to a reduction op. One overload forwards to the op's
// simd_reduce_impl. The other performs a shuffle-down tree reduction with the
// op's operator(), whose result is complete in lane 0. The enable_if
// conditions that select between them are not visible in this copy.
#define DEFINE_SIMD_REDUCE() \
  template = true> \
  T simd_reduce(T val) { \
    return simd_reduce_impl(val); \
  } \
 \
  template = true> \
  T simd_reduce(T val) { \
    for (short i = simd_size / 2; i > 0; i /= 2) { \
      val = operator()(val, simd_shuffle_down(val, i)); \
    } \
    return val; \
  }

static constant constexpr const uint8_t simd_size = 32;

// Views four bools as one 32-bit word so that a single atomic word operation
// can update one of them.
union bool4_or_uint {
  bool4 b;
  unsigned int i;
};

// Op whose atomic update simply stores the value.
struct None {
  template
  void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) {
    mlx_atomic_store_explicit(out, val, offset);
  }
};

// Logical AND reduction over bool.
template
struct And {
  DEFINE_SIMD_REDUCE()

  bool simd_reduce_impl(bool val) {
    return simd_all(val);
  }

  static constexpr constant bool init = true;

  // Clears byte elem_idx of a word holding four packed bools (via fetch_and)
  // when val is false; true leaves the output unchanged.
  void atomic_update(
      device mlx_atomic* out,
      bool val,
      int elem_idx,
      size_t offset = 0) {
    if (!val) {
      bool4_or_uint update;
      update.b = {true, true, true, true};
      update.b[elem_idx] = false;
      mlx_atomic_fetch_and_explicit(out, update.i, offset);
    }
  }

  // A false value makes the result false, so a plain atomic store suffices.
  void atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) {
    if (!val) {
      mlx_atomic_store_explicit(out, val, offset);
    }
  }

  // Non atomic update
  void update(device bool* out, bool val) {
    *out &= val;
  }

  // Operator
  bool operator()(bool a, bool b) {
    return a && b;
  }
};
template struct Or { DEFINE_SIMD_REDUCE() bool simd_reduce_impl(bool val) { return simd_any(val); } static constexpr constant bool init = false; void atomic_update( device mlx_atomic* out, bool val, int elem_idx, size_t offset = 0) { if (val) { bool4_or_uint update; update.b = {false, false, false, false}; update.b[elem_idx] = true; mlx_atomic_fetch_or_explicit(out, update.i, offset); } } void atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { if (val) { mlx_atomic_store_explicit(out, val, offset); } } // Non atomic update void update(device bool* out, bool val) { *out |= val; } // Operator bool operator()(bool a, bool b) { return a || b; } }; template struct Sum { DEFINE_SIMD_REDUCE() template T simd_reduce_impl(T val) { return simd_sum(val); } static constexpr constant U init = U(0); template void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_add_explicit(out, val, offset); } // Operator U operator()(U a, U b) { return a + b; } }; template struct Prod { DEFINE_SIMD_REDUCE() template T simd_reduce_impl(T val) { return simd_product(val); } static constexpr constant U init = U(1); template void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_mul_explicit(out, val, offset); } // Operator U operator()(U a, U b) { return a * b; } }; template struct Min { DEFINE_SIMD_REDUCE() template metal::enable_if_t, T> simd_reduce_impl(T val) { return simd_min(val); } template metal::enable_if_t, T> simd_reduce_impl(T val) { if (simd_any(val != val)) { return static_cast(NAN); } return simd_min(val); } static constexpr constant U init = Limits::max; template void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_min_explicit(out, val, offset); } // Operator template metal::enable_if_t, T> operator()(T a, T b) { return a < b ? 
a : b; } template metal::enable_if_t, T> operator()(T a, T b) { if (metal::isnan(a) || metal::isnan(b)) { return static_cast(NAN); } else { return a < b ? a : b; } } template <> complex64_t operator()(complex64_t a, complex64_t b) { bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); if (!real_is_nan && !imag_is_nan) { return a < b ? a : b; } else if (real_is_nan && !imag_is_nan) { return complex64_t( static_cast(NAN), a.imag < b.imag ? a.imag : b.imag); } else if (!real_is_nan && imag_is_nan) { return complex64_t( a.real < b.real ? a.real : b.real, static_cast(NAN)); } else { return complex64_t(static_cast(NAN), static_cast(NAN)); } }; }; template struct Max { DEFINE_SIMD_REDUCE() template metal::enable_if_t, T> simd_reduce_impl(T val) { return simd_max(val); } template metal::enable_if_t, T> simd_reduce_impl(T val) { if (simd_any(val != val)) { return static_cast(NAN); } return simd_max(val); } static constexpr constant U init = Limits::min; template void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_max_explicit(out, val, offset); } // Operator template metal::enable_if_t, T> operator()(T a, T b) { return a > b ? a : b; } template metal::enable_if_t, T> operator()(T a, T b) { if (metal::isnan(a) || metal::isnan(b)) { return static_cast(NAN); } else { return a > b ? a : b; } } template <> complex64_t operator()(complex64_t a, complex64_t b) { bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); if (!real_is_nan && !imag_is_nan) { return a > b ? a : b; } else if (real_is_nan && !imag_is_nan) { return complex64_t( static_cast(NAN), a.imag > b.imag ? a.imag : b.imag); } else if (!real_is_nan && imag_is_nan) { return complex64_t( a.real > b.real ? 
a.real : b.real, static_cast(NAN)); } else { return complex64_t(static_cast(NAN), static_cast(NAN)); } } }; ================================================ FILE: mlx/backend/metal/kernels/reduction/reduce_all.h ================================================ // Copyright © 2023-2024 Apple Inc. template < typename T, typename U, typename Op, typename IdxT = int64_t, int N_READS = REDUCE_N_READS> [[kernel]] void all_reduce( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& in_size [[buffer(2)]], const constant size_t& row_size [[buffer(3)]], uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 lsize [[threads_per_threadgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; threadgroup U shared_vals[simd_size]; U total = Op::init; IdxT start_idx = gid.y * IdxT(row_size); IdxT actual_row = (start_idx + row_size <= in_size) ? row_size : in_size - start_idx; IdxT blocks = actual_row / (lsize.x * N_READS); int extra = actual_row - blocks * (lsize.x * N_READS); extra -= lid.x * N_READS; start_idx += lid.x * N_READS; in += start_idx; if (extra >= N_READS) { blocks++; extra = 0; } for (IdxT b = 0; b < blocks; b++) { for (int i = 0; i < N_READS; i++) { total = op(static_cast(in[i]), total); } in += lsize.x * N_READS; } if (extra > 0) { for (int i = 0; i < extra; i++) { total = op(static_cast(in[i]), total); } } // Reduction within simd group total = op.simd_reduce(total); if (simd_per_group > 1) { if (simd_lane_id == 0) { shared_vals[simd_group_id] = total; } // Reduction within thread group threadgroup_barrier(mem_flags::mem_threadgroup); total = lid.x < simd_per_group ? 
shared_vals[lid.x] : op.init; total = op.simd_reduce(total); } if (lid.x == 0) { out[gid.y] = total; } } ================================================ FILE: mlx/backend/metal/kernels/reduction/reduce_col.h ================================================ // Copyright © 2023-2024 Apple Inc. template [[kernel]] void col_reduce_small( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant int64_t& reduction_stride [[buffer(3)]], const constant int* shape [[buffer(4)]], const constant int64_t* strides [[buffer(5)]], const constant int& ndim [[buffer(6)]], const constant int* reduce_shape [[buffer(7)]], const constant int64_t* reduce_strides [[buffer(8)]], const constant int& reduce_ndim [[buffer(9)]], const constant size_t& non_col_reductions [[buffer(10)]], uint3 gid [[threadgroup_position_in_grid]], uint3 gsize [[threadgroups_per_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 lsize [[threads_per_threadgroup]]) { constexpr int n_reads = 4; Op op; LoopedElemToLoc 2)> loop(reduce_ndim); const device T* row; U totals[n_reads]; for (int i = 0; i < n_reads; i++) { totals[i] = Op::init; } IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads; if (column >= reduction_stride) { return; } bool safe = column + n_reads <= reduction_stride; IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); in += in_idx + column; IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); loop.next(lid.y, reduce_shape, reduce_strides); for (IdxT r = lid.y; r < total_rows; r += lsize.y) { row = in + loop.location(); if (safe) { for (int i = 0; i < n_reads; i++) { totals[i] = op(static_cast(row[i]), totals[i]); } } else { U vals[n_reads]; for (int i = 0; i < n_reads; i++) { vals[i] = (column + i < reduction_stride) ? 
static_cast(row[i]) : op.init; } for (int i = 0; i < n_reads; i++) { totals[i] = op(vals[i], totals[i]); } } loop.next(lsize.y, reduce_shape, reduce_strides); } if (lsize.y > 1) { // lsize.y should be <= 8 threadgroup U shared_vals[32 * 8 * n_reads]; for (int i = 0; i < n_reads; i++) { shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i]; } threadgroup_barrier(mem_flags::mem_threadgroup); if (lid.y == 0) { for (int i = 0; i < n_reads; i++) { totals[i] = shared_vals[lid.x * n_reads + i]; } for (uint j = 1; j < lsize.y; j++) { for (int i = 0; i < n_reads; i++) { totals[i] = op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i], totals[i]); } } } } if (lid.y == 0) { out += out_idx * IdxT(reduction_stride) + column; if (safe) { for (int i = 0; i < n_reads; i++) { out[i] = totals[i]; } } else { for (int i = 0; column + i < reduction_stride; i++) { out[i] = totals[i]; } } } } template [[kernel]] void col_reduce_longcolumn( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], const constant int* shape [[buffer(4)]], const constant int64_t* strides [[buffer(5)]], const constant int& ndim [[buffer(6)]], const constant int* reduce_shape [[buffer(7)]], const constant int64_t* reduce_strides [[buffer(8)]], const constant int& reduce_ndim [[buffer(9)]], const constant size_t& non_col_reductions [[buffer(10)]], const constant size_t& out_size [[buffer(11)]], uint3 gid [[threadgroup_position_in_grid]], uint3 gsize [[threadgroups_per_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 lsize [[threads_per_threadgroup]]) { Op op; LoopedElemToLoc 2)> loop(reduce_ndim); const device T* row; IdxT out_idx = gid.x + gsize.x * IdxT(gid.y); IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); in += in_idx + lid.x; U total = Op::init; IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); loop.next(gid.z * lsize.y + lid.y, 
reduce_shape, reduce_strides); for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows; r += lsize.y * gsize.z) { row = in + loop.location(); total = op(static_cast(*row), total); loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides); } threadgroup U shared_vals[32 * 32]; shared_vals[lid.y * lsize.x + lid.x] = total; threadgroup_barrier(mem_flags::mem_threadgroup); if (lid.y == 0) { for (uint i = 1; i < lsize.y; i++) { total = op(total, shared_vals[i * lsize.x + lid.x]); } out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] = total; } } /** * Our approach is the following simple looped approach: * 1. Each thread keeps running totals for BN / n_simdgroups outputs. * 2. Load a tile BM, BN in registers and accumulate in the running totals * 3. Move ahead by BM steps until the column axis and the non column * reductions are exhausted. * 6. If BM == 32 then transpose in SM and simd reduce the running totals. * Otherwise write in shared memory and BN threads accumulate the running * totals with a loop. * 7. 
Write them to the output
 */
// NOTE(review): in this copy several template argument lists appear to have
// been stripped (e.g. `static_cast(...)` with no target type and the
// `LoopedElemToLoc 2)>` declaration). The comments below describe the code as
// if those lists were intact; confirm against the upstream file.
//
// col_reduce_looped: one threadgroup accumulates BM x BN tiles and produces up
// to BN consecutive output columns (starting at BN * gid.x) for the output row
// out_idx = gid.y + gsize.y * gid.z.
//   T / U  : input element type / accumulator and output type.
//   Op     : reduction functor; this kernel uses Op::init, op(a, b) and
//            op.simd_reduce(x).
//   IdxT   : integer type used for index arithmetic.
//   BM, BN : rows per tile step / columns per tile.
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int BM,
    int BN>
[[kernel]] void col_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  // The index math assumes a threadgroup of n_simdgroups simdgroups, i.e.
  // tgp_size threads (presumably matched by the host-side launch).
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  // Each thread accumulates n_reads adjacent columns of the tile, and
  // n_read_blocks threads together cover one tile row.
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  LoopedElemToLoc 2)> loop(reduce_ndim);
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  // offset.x: first tile column owned by this thread.
  // offset.y: tile row owned by this thread.
  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  IdxT column = BN * gid.x + offset.x;
  // True when all n_reads columns of this thread lie inside the row.
  bool safe = column + n_reads <= reduction_stride;

  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
  IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  // Visit reduced rows offset.y, offset.y + BM, ... out of
  // non_col_reductions * reduction_size rows in total. `loop` presumably
  // tracks the memory offset of the current row via reduce_shape /
  // reduce_strides.
  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
  loop.next(offset.y, reduce_shape, reduce_strides);
  for (IdxT r = offset.y; r < total; r += BM) {
    row = in + loop.location();

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast(row[i]), totals[i]);
      }
    } else {
      // Columns at or past reduction_stride contribute op.init.
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(BM, reduce_shape, reduce_strides);
  }

  // We can use a simd reduction to accumulate across BM so each thread writes
  // the partial output to SM and then each simdgroup does BN / n_simdgroups
  // accumulations.
  if (BM == 32) {
    constexpr int n_outputs = BN / n_simdgroups;
    static_assert(
        BM != 32 || n_outputs == n_reads,
        "The tile should be selected such that n_outputs == n_reads");
    // Partials are stored row-major: shared_vals[tile_row * BN + tile_col].
    for (int i = 0; i < n_reads; i++) {
      shared_vals[offset.y * BN + offset.x + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Simdgroup g owns columns [g * n_outputs, (g + 1) * n_outputs). Lane l
    // reads tile row l, so simd_reduce combines the BM rows (assuming
    // simd_size == 32 == BM).
    short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
    for (int i = 0; i < n_outputs; i++) {
      totals[i] =
          op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
    }

    // Write the output.
    if (simd_lane_id == 0) {
      IdxT out_column = BN * gid.x + out_offset.x;
      out += out_idx * IdxT(reduction_stride) + out_column;
      if (out_column + n_outputs <= reduction_stride) {
        for (int i = 0; i < n_outputs; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; out_column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }

  // Each thread holds n_reads partial results. We write them all out to shared
  // memory and threads with offset.y == 0 aggregate the columns and write the
  // outputs.
  else {
    // Within each group of n_reads columns the BM partials of one column are
    // stored contiguously: [x_block][column i][tile row].
    short x_block = offset.x / n_reads;
    for (int i = 0; i < n_reads; i++) {
      shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (offset.y == 0) {
      // j starts at 1: row 0's partial is this thread's own, already in totals.
      for (int i = 0; i < n_reads; i++) {
        for (int j = 1; j < BM; j++) {
          totals[i] =
              op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
        }
      }
    }

    // Write the output.
    if (offset.y == 0) {
      out += out_idx * IdxT(reduction_stride) + column;
      if (safe) {
        for (int i = 0; i < n_reads; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }
}

// col_reduce_2pass: first pass of a two-pass column reduction.
// The grid index full_idx = gid.y + gsize.y * gid.z is split into
// block_idx = full_idx / out_size and out_idx = full_idx % out_size.
// Block block_idx starts at reduced row block_idx * BM and then strides by
// outer_blocks * BM, so (presumably with block_idx in [0, outer_blocks)) the
// blocks interleave over the reduced rows.
// Each block writes its own partial row at out + full_idx * reduction_stride.
// Combining those partials is presumably done by a second kernel not shown
// here.
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int BM,
    int BN>
[[kernel]] void col_reduce_2pass(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;
  // NOTE(review): totals has n_reads entries but is indexed up to n_outputs
  // below. With BM == 32 and simd_size == 32 both equal BN / 8; unlike
  // col_reduce_looped there is no static_assert checking this.
  constexpr int n_outputs = BN / n_simdgroups;
  constexpr short outer_blocks = 32;
  static_assert(BM == 32, "BM should be equal to 32");

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  LoopedElemToLoc 2)> loop(reduce_ndim);
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  IdxT column = BN * gid.x + offset.x;
  bool safe = column + n_reads <= reduction_stride;

  IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
  IdxT block_idx = full_idx / IdxT(out_size);
  IdxT out_idx = full_idx % IdxT(out_size);
  IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
  loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
  for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
    row = in + loop.location();

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast(row[i]), totals[i]);
      }
    } else {
      // Columns at or past reduction_stride contribute op.init.
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
  }

  // We can use a simd reduction to accumulate across BM so each thread writes
  // the partial output to SM and then each simdgroup does BN / n_simdgroups
  // accumulations.
  for (int i = 0; i < n_reads; i++) {
    shared_vals[offset.y * BN + offset.x + i] = totals[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
  for (int i = 0; i < n_outputs; i++) {
    totals[i] =
        op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
  }

  // Write the output.
  if (simd_lane_id == 0) {
    IdxT out_column = BN * gid.x + out_offset.x;
    out += full_idx * IdxT(reduction_stride) + out_column;
    if (out_column + n_outputs <= reduction_stride) {
      for (int i = 0; i < n_outputs; i++) {
        out[i] = totals[i];
      }
    } else {
      for (int i = 0; out_column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}
================================================
FILE: mlx/backend/metal/kernels/reduction/reduce_init.h
================================================
// Copyright © 2023-2024 Apple Inc.

// Fills the output with the reduction identity Op::init, one element per
// thread. NOTE(review): the template parameter list appears to be missing in
// this copy (presumably it declares T and Op).
template [[kernel]] void init_reduce(
    device T* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}
================================================
FILE: mlx/backend/metal/kernels/reduction/reduce_row.h
================================================
// Copyright © 2023-2024 Apple Inc.
// Row reduction utilities // - `per_thread_row_reduce` collaborative partial reduction in the threadgroup // - `threadgroup_reduce` collaborative reduction in the threadgroup such that // lid.x == 0 holds the reduced value // - `thread_reduce` simple loop and reduce the row /** * The thread group collaboratively reduces across the rows with bounds * checking. In the end each thread holds a part of the reduction. */ template < typename T, typename U, typename Op, int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> METAL_FUNC void per_thread_row_reduce( thread U totals[N_WRITES], const device T* inputs[N_WRITES], int blocks, int extra, uint lsize_x, uint lid_x) { Op op; // Set up the accumulator registers for (int i = 0; i < N_WRITES; i++) { totals[i] = Op::init; } // Loop over the reduction size within thread group for (int i = 0; i < blocks; i++) { for (int j = 0; j < N_WRITES; j++) { for (int i = 0; i < N_READS; i++) { totals[j] = op(static_cast(inputs[j][i]), totals[j]); } inputs[j] += lsize_x * N_READS; } } // Separate case for the last set as we close the reduction size int index = lid_x * N_READS; if (index + N_READS <= extra) { for (int j = 0; j < N_WRITES; j++) { for (int i = 0; i < N_READS; i++) { totals[j] = op(static_cast(inputs[j][i]), totals[j]); } } } else { for (int j = 0; j < N_WRITES; j++) { for (int i = 0; index + i < extra; i++) { totals[j] = op(static_cast(inputs[j][i]), totals[j]); } } } } /** * Consecutive rows in a contiguous array. 
*/ template < typename T, typename U, typename Op, int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> METAL_FUNC void per_thread_row_reduce( thread U totals[N_WRITES], const device T* in, const constant size_t& reduction_size, int blocks, int extra, uint lsize_x, uint lid_x) { // Set up the input pointers const device T* inputs[N_WRITES]; inputs[0] = in + lid_x * N_READS; for (int i = 1; i < N_READS; i++) { inputs[i] = inputs[i - 1] + reduction_size; } per_thread_row_reduce( totals, inputs, blocks, extra, lsize_x, lid_x); } /** * Consecutive rows in an arbitrarily ordered array. */ template < typename T, typename U, typename Op, int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> METAL_FUNC void per_thread_row_reduce( thread U totals[N_WRITES], const device T* in, const int64_t row_idx, int blocks, int extra, const constant int* shape, const constant int64_t* strides, const constant int& ndim, uint lsize_x, uint lid_x) { // Set up the input pointers const device T* inputs[N_WRITES]; in += lid_x * N_READS; for (int i = 0; i < N_READS; i++) { inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim); } per_thread_row_reduce( totals, inputs, blocks, extra, lsize_x, lid_x); } /** * Reduce within the threadgroup. 
*/
// Reduces N_WRITES values across the whole threadgroup in two stages:
// 1. op.simd_reduce within each simdgroup.
// 2. If there is more than one simdgroup:
//    - lane 0 of every simdgroup stores its partials to
//      shared_vals[simd_group_id * N_WRITES + i];
//    - after the barrier, threads with lid.x < simd_per_group each load one
//      simdgroup's partials (all other threads use op.init);
//    - a second simd_reduce combines them.
// Assuming a 1-D threadgroup where lid.x == simd_group_id * simd_size +
// simd_lane_id (TODO confirm), only simdgroup 0 sees every partial in stage 2.
// Callers therefore read the result on lid.x == 0.
// NOTE(review): there is no barrier after the final shared_vals reads. A
// caller that reuses shared_vals right after this call would need one.
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void threadgroup_reduce(
    thread U totals[N_WRITES],
    threadgroup U* shared_vals,
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  // Simdgroup first
  for (int i = 0; i < N_WRITES; i++) {
    totals[i] = op.simd_reduce(totals[i]);
  }

  // Across simdgroups
  if (simd_per_group > 1) {
    if (simd_lane_id == 0) {
      for (int i = 0; i < N_WRITES; i++) {
        shared_vals[simd_group_id * N_WRITES + i] = totals[i];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    U values[N_WRITES];
    for (int i = 0; i < N_WRITES; i++) {
      values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]
                                           : op.init;
    }

    for (int i = 0; i < N_WRITES; i++) {
      totals[i] = op.simd_reduce(values[i]);
    }
  }
}

// Single-thread reduction of one row into `total`, which is accumulated into
// and not reset:
// - `blocks` groups of N_READS elements, each group loaded into registers
//   first;
// - then `extra` trailing elements, one at a time.
// NOTE(review): the template parameter list appears to be missing in this
// copy; the body uses T, U, Op and N_READS.
template METAL_FUNC void
thread_reduce(thread U& total, const device T* row, int blocks, int extra) {
  Op op;
  for (int i = 0; i < blocks; i++) {
    U vals[N_READS];
    for (int j = 0; j < N_READS; j++) {
      vals[j] = row[j];
    }
    for (int j = 0; j < N_READS; j++) {
      total = op(vals[j], total);
    }
    row += N_READS;
  }
  for (int i = 0; i < extra; i++) {
    total = op(*row++, total);
  }
}

// Reduction kernels
// - `row_reduce_small`: depending on the number of non-row reductions and the
//   row size, either one thread loops over everything, or a simdgroup
//   collaboratively reduces the non-row reductions. In the first case one
//   thread is responsible for one output; in the second, one simdgroup is.
// - `row_reduce_simple`: simple contiguous row reduction.
// - `row_reduce_looped`: simply loop and reduce each row for each non-row
//   reduction. One threadgroup is responsible for one output.
// row_reduce_small: reduces rows of length row_size over non_row_reductions
// rows per output. The strategy is picked at runtime:
// - few non-row reductions: one thread per output (grid indexed by tid);
// - otherwise: one simdgroup per output (grid indexed by gid).
// NOTE(review): several template argument lists (LoopedElemToLoc,
// thread_reduce, ...) appear to be stripped in this copy.
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int64_t& row_size [[buffer(2)]],
    const constant int64_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 tid [[thread_position_in_grid]],
    uint3 tsize [[threads_per_grid]]) {
  Op op;

  U total_val = Op::init;
  LoopedElemToLoc 2)> loop(reduce_ndim);

  // Precompute some row reduction numbers
  const device T* row;
  int blocks = IdxT(row_size) / N_READS;
  int extra = IdxT(row_size) % N_READS;

  if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
    // Simple loop over non_row_reductions and reduce the row in the thread.
    IdxT out_idx = tid.x + tsize.x * IdxT(tid.y);
    in += elem_to_loc(out_idx, shape, strides, ndim);

    for (uint r = 0; r < non_row_reductions; r++) {
      row = in + loop.location();
      thread_reduce(total_val, row, blocks, extra);
      loop.next(reduce_shape, reduce_strides);
    }

    out[out_idx] = total_val;
  } else {
    // Collaboratively reduce over non_row_reductions in the simdgroup. Each
    // thread reduces every simd_size-th row, then a single simd reduce
    // combines the lanes.
    IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
    in += elem_to_loc(out_idx, shape, strides, ndim);

    loop.next(simd_lane_id, reduce_shape, reduce_strides);

    // NOTE(review): `r` is a 32-bit uint compared against the int64_t
    // non_row_reductions. This is only correct while non_row_reductions fits
    // in 32 bits; row_reduce_looped uses IdxT for its counter instead.
    for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
      row = in + loop.location();
      thread_reduce(total_val, row, blocks, extra);
      loop.next(simd_size, reduce_shape, reduce_strides);
    }

    total_val = op.simd_reduce(total_val);

    if (simd_lane_id == 0) {
      out[out_idx] = total_val;
    }
  }
}

// row_reduce_simple: each threadgroup reduces N_WRITES consecutive contiguous
// rows of reduction_size elements and writes N_WRITES outputs.
// The last threadgroup is shifted back to out_size - N_WRITES so it never
// writes past out_size. It may therefore rewrite outputs that overlap the
// previous threadgroup's. This presumes out_size >= N_WRITES (TODO confirm
// on the host side).
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT = int64_t,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& out_size [[buffer(3)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  threadgroup U shared_vals[simd_size * N_WRITES];
  U totals[N_WRITES];

  // Move to the row
  IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
  if (out_idx + N_WRITES > out_size) {
    out_idx = out_size - N_WRITES;
  }
  in += out_idx * IdxT(reduction_size);
  out += out_idx;

  // Each thread reduces across the row
  int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
  int extra = reduction_size - blocks * (lsize.x * N_READS);
  per_thread_row_reduce(
      totals, in, reduction_size, blocks, extra, lsize.x, lid.x);

  // Reduce across the threadgroup
  threadgroup_reduce(
      totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Write the output
  if (lid.x == 0) {
    for (int i = 0; i < N_WRITES; i++) {
      out[i] = totals[i];
    }
  }
}

// row_reduce_looped: one threadgroup per output (out_idx from gid).
// - For each of the non_row_reductions rows, the threadgroup reduces the row
//   with per_thread_row_reduce into a single accumulator (presumably
//   instantiated with N_WRITES == 1; the template arguments are not visible in
//   this copy).
// - The per-row results are folded into `total`.
// - threadgroup_reduce then combines the threads, and lid.x == 0 writes the
//   result.
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int64_t& row_size [[buffer(2)]],
    const constant int64_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U shared_vals[simd_size];
  U total = Op::init;

  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);

  // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
  // needs a small refactor.
  in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;

  LoopedElemToLoc 2)> loop(reduce_ndim);
  const device T* row;
  int blocks = IdxT(row_size) / (lsize.x * N_READS);
  int extra = row_size - blocks * (lsize.x * N_READS);

  for (IdxT i = 0; i < non_row_reductions; i++) {
    row = in + loop.location();

    // Each thread reduces across the row
    U row_total;
    per_thread_row_reduce(
        &row_total, &row, blocks, extra, lsize.x, lid.x);

    // Aggregate across rows
    total = op(total, row_total);

    loop.next(reduce_shape, reduce_strides);
  }

  // Reduce across the threadgroup
  threadgroup_reduce(
      &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Write the output
  if (lid.x == 0) {
    out[out_idx] = total;
  }
}
================================================
FILE: mlx/backend/metal/kernels/rms_norm.metal
================================================
// Copyright © 2024 Apple Inc.
// NOTE(review): the targets of the first two #include directives, and several
// template parameter / argument lists below (bare `template`, `static_cast(`),
// appear to be missing in this copy; comments describe the intended code.
#include
#include
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

// Function constant 20: when set, the VJP kernels also write the weight
// gradient `gw`.
constant bool has_w [[function_constant(20)]];

// rms_single_row: normalizes row `gid` (axis_size elements) and scales it by
// the weight:
//   inv_rms = rsqrt(sum(x^2) / axis_size + eps)
//   out[i]  = w[i * w_stride] * cast(x[i] * inv_rms)
// Squares are accumulated in float.
// Each thread touches only the N_READS elements starting at lid * N_READS, so
// the whole row presumably fits in one pass of the threadgroup (TODO confirm
// on the host side).
template [[kernel]] void rms_single_row(
    const device T* x,
    const device T* w,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  constexpr int SIMD_SIZE = 32;

  // local_inv_mean holds the inverse RMS of the row.
  threadgroup float local_inv_mean[1];
  threadgroup float local_sums[SIMD_SIZE];

  float acc = 0;
  x += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      float xi = x[i];
      acc += xi * xi;
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        float xi = x[i];
        acc += xi * xi;
      }
    }
  }
  acc = simd_sum(acc);
  // Initialize shared memory. Zeroing every slot lets the final simd_sum read
  // all SIMD_SIZE slots even when there are fewer simdgroups.
  if (simd_group_id == 0) {
    local_sums[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sums[simd_group_id] = acc;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    acc = simd_sum(local_sums[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write the outputs
  out += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = w[w_stride * i] * static_cast(x[i] * local_inv_mean[0]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        out[i] = w[w_stride * i] * static_cast(x[i] * local_inv_mean[0]);
      }
    }
  }
}

// rms_looped: same computation as rms_single_row, for rows that need several
// passes of the threadgroup.
// - Each thread visits the chunks at r + lid * N_READS for
//   r = 0, lsize * N_READS, ...
// - x is read a second time in the output pass.
template [[kernel]] void rms_looped(
    const device T* x,
    const device T* w,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  constexpr int SIMD_SIZE = 32;
  threadgroup float local_inv_mean[1];
  threadgroup float local_sums[SIMD_SIZE];

  float acc = 0;
  x += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        acc += xi * xi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          acc += xi * xi;
        }
      }
    }
  }
  acc = simd_sum(acc);
  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sums[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sums[simd_group_id] = acc;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    acc = simd_sum(local_sums[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write the outputs
  out += gid * size_t(axis_size) + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        out[r + i] = w[w_stride * (i + r)] *
            static_cast(x[r + i] * local_inv_mean[0]);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          out[r + i] = w[w_stride * (i + r)] *
              static_cast(x[r + i] * local_inv_mean[0]);
        }
      }
    }
  }
}

// vjp_rms_single_row: backward pass for one row. With
//   n       = rsqrt(sum(x^2) / axis_size + eps)
//   meangwx = sum(g * w * x) / axis_size
// it writes
//   gx = g * w * n - x * meangwx * n^3
//   gw = g * x * n                        (only when has_w is set)
// gw is written per row (offset gid * axis_size); reducing it over rows is
// presumably done by the caller.
template [[kernel]] void vjp_rms_single_row(
    const device T* x,
    const device T* w,
    const device T* g,
    device T* gx,
    device T* gw,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * size_t(axis_size) + lid * N_READS;
  g += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;

  // Allocate registers for the computation and accumulators
  float thread_x[N_READS];
  float thread_w[N_READS];
  float thread_g[N_READS];
  float sumx2 = 0;
  float sumgwx = 0;

  // Allocate shared memory to implement the reduction
  constexpr int SIMD_SIZE = 32;
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_sumgwx[SIMD_SIZE];
  threadgroup float local_normalizer[1];
  threadgroup float local_meangwx[1];

  // Read and accumulate locally
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = x[i];
      thread_w[i] = w[w_stride * i];
      thread_g[i] = g[i];

      sumx2 += thread_x[i] * thread_x[i];
      sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = x[i];
        thread_w[i] = w[w_stride * i];
        thread_g[i] = g[i];

        sumx2 += thread_x[i] * thread_x[i];
        sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
      }
    }
  }

  // Accumulate across threads
  sumx2 = simd_sum(sumx2);
  sumgwx = simd_sum(sumgwx);
  if (simd_group_id == 0) {
    local_sumx2[simd_lane_id] = 0;
    local_sumgwx[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_lane_id == 0) {
    local_sumx2[simd_group_id] = sumx2;
    local_sumgwx[simd_group_id] = sumgwx;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_meangwx[0] = sumgwx / axis_size;
      local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  float meangwx = local_meangwx[0];
  float normalizer = local_normalizer[0];
  float normalizer3 = normalizer * normalizer * normalizer;

  // Write the outputs
  gx += gid * size_t(axis_size) + lid * N_READS;
  gw += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      gx[i] = static_cast(
          thread_g[i] * thread_w[i] * normalizer -
          thread_x[i] * meangwx * normalizer3);
      if (has_w) {
        gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer);
      }
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        gx[i] = static_cast(
            thread_g[i] * thread_w[i] * normalizer -
            thread_x[i] * meangwx * normalizer3);
        if (has_w) {
          gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer);
        }
      }
    }
  }
}

// vjp_rms_looped: same math as vjp_rms_single_row for rows that need several
// passes of the threadgroup. Inputs are re-read in the output pass instead of
// being kept in registers.
template [[kernel]] void vjp_rms_looped(
    const device T* x,
    const device T* w,
    const device T* g,
    device T* gx,
    device T* gw,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * size_t(axis_size) + lid * N_READS;
  g += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;

  // Allocate registers for the accumulators
  float sumx2 = 0;
  float sumgwx = 0;

  // Allocate shared memory to implement the reduction
  constexpr int SIMD_SIZE = 32;
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_sumgwx[SIMD_SIZE];
  threadgroup float local_normalizer[1];
  threadgroup float local_meangwx[1];

  // Read and accumulate locally
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        float wi = w[w_stride * (i + r)];
        float gi = g[i + r];

        sumx2 += xi * xi;
        sumgwx += xi * wi * gi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          float wi = w[w_stride * (i + r)];
          float gi = g[i + r];

          sumx2 += xi * xi;
          sumgwx += xi * wi * gi;
        }
      }
    }
  }

  // Accumulate across threads
  sumx2 = simd_sum(sumx2);
  sumgwx = simd_sum(sumgwx);
  if (simd_group_id == 0) {
    local_sumx2[simd_lane_id] = 0;
    local_sumgwx[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_lane_id == 0) {
    local_sumx2[simd_group_id] = sumx2;
    local_sumgwx[simd_group_id] = sumgwx;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_meangwx[0] = sumgwx / axis_size;
      local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  float meangwx = local_meangwx[0];
  float normalizer = local_normalizer[0];
  float normalizer3 = normalizer * normalizer * normalizer;

  // Write the outputs
  gx += gid * size_t(axis_size) + lid * N_READS;
  gw += gid * size_t(axis_size) + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        float wi = w[w_stride * (i + r)];
        float gi = g[i + r];

        gx[i + r] =
            static_cast(gi * wi * normalizer - xi * meangwx * normalizer3);
        if (has_w) {
          gw[i + r] = static_cast(gi * xi * normalizer);
        }
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          float wi = w[w_stride * (i + r)];
          float gi = g[i + r];

          gx[i + r] =
              static_cast(gi * wi * normalizer - xi * meangwx * normalizer3);
          if (has_w) {
            gw[i + r] = static_cast(gi * xi * normalizer);
          }
        }
      }
    }
  }
}

// Kernel names are "rms" / "vjp_rms" / "rms_looped" / "vjp_rms_looped"
// followed by the type name.
// clang-format off
#define instantiate_rms(name, itype) \
  instantiate_kernel("rms" #name, rms_single_row, itype) \
  instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \
  instantiate_kernel("rms_looped" #name, rms_looped, itype) \
  instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype)

instantiate_rms(float32, float)
instantiate_rms(float16, half)
instantiate_rms(bfloat16, bfloat16_t)
// clang-format on
================================================
FILE: mlx/backend/metal/kernels/rope.metal
================================================
// Copyright © 2023-2024 Apple Inc.

// NOTE(review): the first include target appears to be missing in this copy.
#include
#include "mlx/backend/metal/kernels/utils.h"

// Function constants chosen when the pipeline is built:
//   forward      - rotate by +theta; otherwise apply the inverse rotation.
//   traditional  - rotate adjacent pairs (2k, 2k + 1) instead of
//                  (k, k + grid.x).
//   hs_transpose - alternative input addressing in rope_impl.
constant bool forward [[function_constant(1)]];
constant bool traditional [[function_constant(2)]];
constant bool hs_transpose [[function_constant(3)]];

// Rotates one element pair of row pos.y by theta = scale * offset * inv_freq.
// Every row uses the same position `offset`.
// NOTE(review): indices are computed in 32-bit `uint` from the int64_t
// stride, so this path presumably requires inputs addressable with 32 bits.
template void rope_single_impl(
    const device T* in,
    device T* out,
    constant const int& offset,
    const float inv_freq,
    constant const float& scale,
    constant const int64_t& stride,
    uint2 pos,
    uint2 grid) {
  float L = scale * static_cast(offset);

  // Compute costheta, sintheta
  float theta = L * inv_freq;
  float costheta = metal::fast::cos(theta);
  float sintheta = metal::fast::sin(theta);

  // Compute the input and output indices
  uint index_1, index_2;
  if (traditional) {
    index_1 = 2 * pos.x + pos.y * stride;
    index_2 = index_1 + 1;
  } else {
    index_1 = pos.x + pos.y * stride;
    index_2 = index_1 + grid.x;
  }

  // Read and write the output
  float x1 = static_cast(in[index_1]);
  float x2 = static_cast(in[index_2]);
  float rx1;
  float rx2;
  if (forward) {
    rx1 = x1 * costheta - x2 * sintheta;
    rx2 = x1 * sintheta + x2 * costheta;
  } else {
    rx1 = x2 * sintheta + x1 * costheta;
    rx2 = x2 * costheta - x1 * sintheta;
  }
  out[index_1] = static_cast(rx1);
  out[index_2] = static_cast(rx2);
}

// inv_freq = exp2(-d * base) with d = pos.x / grid.x, i.e. `base` is
// presumably log2 of the RoPE base frequency (TODO confirm on the host side).
template [[kernel]] void rope_single(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const int& offset,
    constant const float& scale,
    constant const int64_t& stride,
    constant const float& base [[buffer(10)]],
    uint2 pos [[thread_position_in_grid]],
    uint2 grid [[threads_per_grid]]) {
  float d = static_cast(pos.x) / static_cast(grid.x);
  float inv_freq = metal::exp2(-d * base);
  rope_single_impl(in, out, offset, inv_freq, scale, stride, pos, grid);
}

// Variant taking explicit frequencies: inv_freq = 1 / freqs[freq_stride * pos.x].
template [[kernel]] void rope_single_freqs(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const int& offset,
    constant const float& scale,
    constant const int64_t& stride,
    const device float* freqs [[buffer(10)]],
    constant const int64_t& freq_stride [[buffer(11)]],
    uint2 pos [[thread_position_in_grid]],
    uint2 grid [[threads_per_grid]]) {
  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
  rope_single_impl(in, out, offset, inv_freq, scale, stride, pos, grid);
}

// Rotates one element pair for up to N heads per thread.
// - pos.z enumerates (batch, head group): heads are padded up to n_head_up
//   (a multiple of N).
// - This thread handles heads head_idx, head_idx + 1, ... (stopping at
//   n_head) of batch batch_idx.
// - The rotation position is pos.y + offset[batch_idx * offset_stride].
template void rope_impl(
    const device T* in,
    device T* out,
    const device int* offset,
    const float inv_freq,
    constant const float& scale,
    constant const int64_t strides[3],
    constant const int64_t out_strides[3],
    constant const int64_t& offset_stride,
    constant const int& n_head,
    uint3 pos,
    uint3 grid) {
  auto n_head_up = N * ((n_head + N - 1) / N);
  auto head_idx = static_cast((pos.z * N) % n_head_up);
  auto batch_idx = (pos.z * N) / n_head_up;
  auto batch_offset = offset[batch_idx * offset_stride];
  float L = scale * static_cast(pos.y + batch_offset);
  auto mat_idx = batch_idx * n_head + head_idx;

  // Compute costheta, sintheta
  float theta = L * inv_freq;
  float costheta = metal::fast::cos(theta);
  float sintheta = metal::fast::sin(theta);

  // Compute the input and output indices
  IdxT in_index_1;
  if (hs_transpose) {
    IdxT batch_stride = grid.y * IdxT(strides[1]);
    in_index_1 =
        batch_idx * batch_stride + pos.y * strides[1] + head_idx * strides[0];
  } else {
    in_index_1 = pos.y * IdxT(strides[1]) + mat_idx * IdxT(strides[0]);
  }
  IdxT in_index_2;
  IdxT out_index_1 =
      pos.y * IdxT(out_strides[1]) + mat_idx * IdxT(out_strides[0]);
  IdxT out_index_2;
  if (traditional) {
    // NOTE(review): out_index_2 = out_index_1 + 1 assumes out_strides[2] == 1,
    // while the input side uses strides[2]. Presumably the output is always
    // contiguous in the last axis; confirm.
    out_index_1 += 2 * pos.x * IdxT(out_strides[2]);
    out_index_2 = out_index_1 + 1;
    in_index_1 += 2 * pos.x * IdxT(strides[2]);
    in_index_2 = in_index_1 + IdxT(strides[2]);
  } else {
    out_index_1 += pos.x * IdxT(out_strides[2]);
    out_index_2 = out_index_1 + grid.x * IdxT(out_strides[2]);
    in_index_1 += pos.x * IdxT(strides[2]);
    in_index_2 = in_index_1 + grid.x * IdxT(strides[2]);
  }
  for (int i = 0; i < N && head_idx + i < n_head; ++i) {
    // Read and write the output
    float x1 = static_cast(in[in_index_1]);
    float x2 = static_cast(in[in_index_2]);
    float rx1;
    float rx2;
    if (forward) {
      rx1 = x1 * costheta - x2 * sintheta;
      rx2 = x1 * sintheta + x2 * costheta;
    } else {
      rx1 = x2 * sintheta + x1 * costheta;
      rx2 = x2 * costheta - x1 * sintheta;
    }
    out[out_index_1] = static_cast(rx1);
    out[out_index_2] = static_cast(rx2);
    // Step to the next head.
    in_index_1 += IdxT(strides[0]);
    in_index_2 += IdxT(strides[0]);
    out_index_1 += IdxT(out_strides[0]);
    out_index_2 += IdxT(out_strides[0]);
  }
}

// Multi-head kernel with inv_freq = exp2(-(pos.x / grid.x) * base).
template [[kernel]] void rope(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    const device int* offset,
    constant const float& scale,
    constant const int64_t strides[3],
    constant const int64_t out_strides[3],
    constant const int64_t& offset_stride,
    constant const int& n_head,
    constant const float& base [[buffer(10)]],
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  float d = static_cast(pos.x) / static_cast(grid.x);
  float inv_freq = metal::exp2(-d * base);
  rope_impl(
      in,
      out,
      offset,
      inv_freq,
      scale,
      strides,
      out_strides,
      offset_stride,
      n_head,
      pos,
      grid);
}

// Multi-head kernel with explicit frequencies:
// inv_freq = 1 / freqs[freq_stride * pos.x].
template [[kernel]] void rope_freqs(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    const device int* offset,
    constant const float& scale,
    constant const int64_t strides[3],
    constant const int64_t out_strides[3],
    constant const int64_t& offset_stride,
    constant const int& n_head,
    const device float* freqs [[buffer(10)]],
    constant const int64_t& freq_stride [[buffer(11)]],
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
  rope_impl(
      in,
      out,
      offset,
      inv_freq,
      scale,
      strides,
      out_strides,
      offset_stride,
      n_head,
      pos,
      grid);
}

// Kernel families per type:
// - rope_ / rope_freqs_ use int32_t indices;
// - rope_large_ / rope_freqs_large_ use int64_t indices;
// - rope_single_ / rope_single_freqs_ are the fixed-offset variants.
// clang-format off
#define instantiate_rope_g(name, type) \
  instantiate_kernel("rope_" #name, rope, type, int32_t) \
  instantiate_kernel("rope_freqs_" #name, rope_freqs, type, int32_t) \
  instantiate_kernel("rope_large_" #name, rope, type, int64_t) \
  instantiate_kernel("rope_freqs_large_" #name, rope_freqs, type, int64_t)

#define instantiate_rope_s(name, type) \
  instantiate_kernel("rope_single_" #name, rope_single, type) \
  instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type)

#define instantiate_rope(name, type) \
  instantiate_rope_s(name, type) \
  instantiate_rope_g(name, type)

instantiate_rope(float16, half)
instantiate_rope(bfloat16, bfloat16_t)
instantiate_rope(float32, float)
// clang-format on
================================================
FILE: mlx/backend/metal/kernels/scaled_dot_product_attention.metal
================================================
// NOTE(review): the include target below appears to be missing in this copy.
#include

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"

using namespace metal;

// SDPA vector instantiations
// For each dtype and head dimension this instantiates sdpa_vector,
// sdpa_vector_2pass_1 and the sdpa_vector_2pass_2 aggregation kernel.
#define instantiate_sdpa_vector_aggregation(type, value_dim) \
  instantiate_kernel(                                      \
      "sdpa_vector_2pass_2_" #type "_" #value_dim,         \
      sdpa_vector_2pass_2,                                 \
      type,                                                \
      value_dim)

#define instantiate_sdpa_vector(type, qk_dim, value_dim)       \
  instantiate_kernel(                                          \
      "sdpa_vector_" #type "_" #qk_dim "_" #value_dim,         \
      sdpa_vector,                                             \
      type,                                                    \
      qk_dim,                                                  \
      value_dim)                                               \
  instantiate_kernel(                                          \
      "sdpa_vector_2pass_1_" #type "_" #qk_dim "_" #value_dim, \
      sdpa_vector_2pass_1,                                     \
      type,                                                    \
      qk_dim,                                                  \
      value_dim)

#define instantiate_sdpa_vector_heads(type)      \
  instantiate_sdpa_vector(type, 64, 64)          \
  instantiate_sdpa_vector(type, 96, 96)          \
  instantiate_sdpa_vector(type, 128, 128)        \
  instantiate_sdpa_vector(type, 256, 256)        \
  instantiate_sdpa_vector_aggregation(type, 64)  \
  instantiate_sdpa_vector_aggregation(type, 96)  \
  instantiate_sdpa_vector_aggregation(type, 128) \
  instantiate_sdpa_vector_aggregation(type, 256)

instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)
instantiate_sdpa_vector_heads(float16_t)
// clang-format on
================================================
FILE: 
mlx/backend/metal/kernels/scan.h ================================================
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/binary_ops.h"

// Injects two `simd_scan` overloads into a scan Op:
//  - one forwards to the Op's `simd_scan_impl` (a SIMD-group prefix
//    intrinsic);
//  - the fallback does an inclusive scan in log steps (shift distances 1, 2,
//    4, 8, 16, covering a 32-lane SIMD-group) with simd_shuffle_and_fill_up.
//    Lanes shifted in from below receive `init`.
// NOTE(review): the template parameter lists that pick which overload applies
// are not legible in this copy of the source.
#define DEFINE_SIMD_SCAN() \
  template = true> \
  T simd_scan(T val) { \
    return simd_scan_impl(val); \
  } \
  \
  template = true> \
  T simd_scan(T val) { \
    for (int i = 1; i <= 16; i *= 2) { \
      val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); \
    } \
    return val; \
  }

// Same pattern for the exclusive scan. The fallback runs the inclusive scan
// and then shifts the result up by one lane, so lane 0 receives `init`.
#define DEFINE_SIMD_EXCLUSIVE_SCAN() \
  template = true> \
  T simd_exclusive_scan(T val) { \
    return simd_exclusive_scan_impl(val); \
  } \
  \
  template = true> \
  T simd_exclusive_scan(T val) { \
    val = simd_scan(val); \
    return simd_shuffle_and_fill_up(val, init, 1); \
  }

// Cumulative sum: identity 0, uses the SIMD-group prefix-sum intrinsics.
template
struct CumSum {
  DEFINE_SIMD_SCAN()
  DEFINE_SIMD_EXCLUSIVE_SCAN()

  static constexpr constant U init = static_cast(0);

  template
  U operator()(U a, T b) {
    return a + b;
  }

  U simd_scan_impl(U x) {
    return simd_prefix_inclusive_sum(x);
  }

  U simd_exclusive_scan_impl(U x) {
    return simd_prefix_exclusive_sum(x);
  }
};

// Cumulative product: identity 1, uses the SIMD-group prefix-product
// intrinsics.
template
struct CumProd {
  DEFINE_SIMD_SCAN()
  DEFINE_SIMD_EXCLUSIVE_SCAN()

  static constexpr constant U init = static_cast(1.0f);

  template
  U operator()(U a, T b) {
    return a * b;
  }

  U simd_scan_impl(U x) {
    return simd_prefix_inclusive_product(x);
  }

  U simd_exclusive_scan_impl(U x) {
    return simd_prefix_exclusive_product(x);
  }
};

// Explicit specialization (presumably for bool): the product is a logical AND
// with identity `true`, scanned with the shuffle-based log-step loop.
template <>
struct CumProd {
  static constexpr constant bool init = true;

  template
  bool operator()(bool a, T b) {
    return a & static_cast(b);
  }

  bool simd_scan(bool x) {
    for (int i = 1; i <= 16; i *= 2) {
      bool other = simd_shuffle_and_fill_up(x, init, i);
      x &= other;
    }
    return x;
  }

  bool simd_exclusive_scan(bool x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};

// Cumulative maximum: identity Limits::min, shuffle-based log-step scan.
template
struct CumMax {
  static constexpr constant U init = Limits::min;

  template
  U operator()(U a, T b) {
    return (a >= b) ? a : b;
  }

  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_and_fill_up(x, init, i);
      x = (x >= other) ? x : other;
    }
    return x;
  }

  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};

// Cumulative minimum: identity Limits::max, shuffle-based log-step scan.
template
struct CumMin {
  static constexpr constant U init = Limits::max;

  template
  U operator()(U a, T b) {
    return (a <= b) ? a : b;
  }

  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_and_fill_up(x, init, i);
      x = (x <= other) ? x : other;
    }
    return x;
  }

  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};

// Cumulative log-add-exp, combining with LogAddExp from binary_ops.h.
// Identity Limits::min.
template
struct CumLogaddexp {
  static constexpr constant U init = Limits::min;

  template
  U operator()(U a, T b) {
    return LogAddExp{}(a, static_cast(b));
  }

  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_and_fill_up(x, init, i);
      x = LogAddExp{}(x, other);
    }
    return x;
  }

  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};

// Loads N_READS consecutive elements with no bounds check. When `reverse` is
// set they are stored in reverse order, so values[] is always in scan order.
template
inline void load_unsafe(U values[N_READS], const device T* input) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      values[N_READS - i - 1] = input[i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      values[i] = input[i];
    }
  }
}

// Bounds-checked variant of load_unsafe. `start` is the scan-order position
// of values[0]. Positions at or beyond `total` are filled with `init` instead
// of being read.
template
inline void load_safe(
    U values[N_READS],
    const device T* input,
    int start,
    int total,
    U init) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      values[N_READS - i - 1] =
          (start + N_READS - i - 1 < total) ? input[i] : init;
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      values[i] = (start + i < total) ? input[i] : init;
    }
  }
}

// Stores N_READS values with no bounds check, undoing the reversal applied by
// load_unsafe when `reverse` is set.
template
inline void write_unsafe(U values[N_READS], device U* out) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = values[N_READS - i - 1];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      out[i] = values[i];
    }
  }
}

// Bounds-checked variant of write_unsafe. Only scan-order positions below
// `total` are written.
template
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      if (start + N_READS - i - 1 < total) {
        out[i] = values[N_READS - i - 1];
      }
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if (start + i < total) {
        out[i] = values[i];
      }
    }
  }
}

// Scan along a contiguous axis. One threadgroup handles one row of
// `axis_size` elements; the row index is gid.y + gsize.y * gid.z.
// Each loop iteration scans a block of N_READS * lsize.x elements and carries
// the block total forward in `prefix`.
// Exclusive scans write `Op::init` to the first output position and shift the
// remaining outputs by one.
// NOTE(review): simdgroup_sums holds 32 entries, so this assumes at most 32
// SIMD-groups per threadgroup (lsize.x <= 1024).
template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
[[kernel]] void contiguous_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  constexpr int simd_size = 32;
  Op op;

  // Position the pointers
  size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size;
  in += offset;
  out += offset;

  // Compute the number of simd_groups
  uint simd_groups = lsize.x / simd_size;

  // Allocate memory
  U prefix = Op::init;
  U values[N_READS];
  threadgroup U simdgroup_sums[32];

  // Loop over the scanned axis in ceildiv(axis_size, N_READS * lsize) blocks
  // of N_READS * lsize elements each:
  //   - Read block
  //   - Compute inclusive scan of the block:
  //       - inclusive scan per thread
  //       - exclusive scan of the thread totals within each simdgroup
  //       - write the simdgroup totals to threadgroup memory
  //       - exclusive scan of the simdgroup totals
  //   - Compute the output by combining prefix, prev_simdgroup, prev_thread
  //     and the value
  //   - Write block
  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
    // Compute the block offset. This shadows the row offset above and is the
    // scan-order position of this thread's first element.
    uint offset = r * lsize.x * N_READS + lid.x * N_READS;

    // Read the values
    if (reverse) {
      if ((offset + N_READS) < axis_size) {
        load_unsafe(
            values, in + axis_size - offset - N_READS);
      } else {
        load_safe(
            values,
            in + axis_size - offset - N_READS,
            offset,
            axis_size,
            Op::init);
      }
    } else {
      if ((offset + N_READS) < axis_size) {
        load_unsafe(values, in + offset);
      } else {
        load_safe(
            values, in + offset, offset, axis_size, Op::init);
      }
    }

    // Compute an inclusive scan per thread
    for (int i = 1; i < N_READS; i++) {
      values[i] = op(values[i], values[i - 1]);
    }

    // Compute exclusive scan of thread sums
    U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);

    // Write each simdgroup's total (from its last lane) to threadgroup memory
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_lane_id == simd_size - 1) {
      simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Compute exclusive scan of simdgroup_sums. Entries at or beyond
    // simd_groups are never written this round, but an exclusive scan result
    // at lane k only depends on lanes below k, so they don't reach any used
    // result.
    if (simd_group_id == 0) {
      U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
      simdgroup_sums[simd_lane_id] = prev_simdgroup;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Compute the output
    for (int i = 0; i < N_READS; i++) {
      values[i] = op(values[i], prefix);
      values[i] = op(values[i], simdgroup_sums[simd_group_id]);
      values[i] = op(values[i], prev_thread);
    }

    // Write the values
    if (reverse) {
      if (inclusive) {
        if ((offset + N_READS) < axis_size) {
          write_unsafe(
              values, out + axis_size - offset - N_READS);
        } else {
          write_safe(
              values, out + axis_size - offset - N_READS, offset, axis_size);
        }
      } else {
        // Exclusive: the first position in scan order gets the identity and
        // every value lands one position later.
        if (lid.x == 0 && offset == 0) {
          out[axis_size - 1] = Op::init;
        }
        if ((offset + N_READS + 1) < axis_size) {
          write_unsafe(
              values, out + axis_size - offset - 1 - N_READS);
        } else {
          write_safe(
              values,
              out + axis_size - offset - 1 - N_READS,
              offset + 1,
              axis_size);
        }
      }
    } else {
      if (inclusive) {
        if ((offset + N_READS) < axis_size) {
          write_unsafe(values, out + offset);
        } else {
          write_safe(
              values, out + offset, offset, axis_size);
        }
      } else {
        // Exclusive: out[0] gets the identity and every value lands one
        // position later.
        if (lid.x == 0 && offset == 0) {
          out[0] = Op::init;
        }
        if ((offset + N_READS + 1) < axis_size) {
          write_unsafe(values, out + offset + 1);
        } else {
          write_safe(
              values, out + offset + 1, offset + 1, axis_size);
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Share the prefix: the last lane of the last simdgroup holds the running
    // total of everything scanned so far.
    if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
      simdgroup_sums[0] = values[N_READS - 1];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    prefix = simdgroup_sums[0];
  }
}

// Scan along an axis whose consecutive elements are `stride` apart.
// The flattened threadgroup index (gid.y + gsize.y * gid.z) is split as:
//   / stride_blocks -> which outer slice of axis_size * stride elements,
//   % stride_blocks -> which group of BN = 32 adjacent columns.
// The axis is walked in tiles of BM = 32 rows:
//   - each tile is staged in threadgroup memory (read_buffer);
//   - each SIMD-group then scans n_scans columns with one lane per row;
//   - the running per-column total is carried in prefix[] via lane 31.
// Exclusive scans write `Op::init` to the first row and shift outputs by one
// row.
template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
[[kernel]] void strided_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    const constant size_t& stride [[buffer(3)]],
    const constant size_t& stride_blocks [[buffer(4)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  constexpr int simd_size = 32;
  constexpr int BM = 32;
  constexpr int BN = 32;
  // Row padding of the staging tile (presumably to reduce threadgroup memory
  // bank conflicts on the transposed read).
  constexpr int BN_pad = 32 + 16 / sizeof(U);
  constexpr int n_simds = BN / N_READS;
  constexpr int n_scans = BN / n_simds;
  Op op;

  threadgroup U read_buffer[BM * BN_pad];
  U values[n_scans];
  U prefix[n_scans];
  for (int i = 0; i < n_scans; i++) {
    prefix[i] = Op::init;
  }

  // Compute offsets
  size_t full_gid = gid.y + gsize.y * size_t(gid.z);
  size_t offset = full_gid / stride_blocks * axis_size * stride;
  size_t global_index_x = full_gid % stride_blocks * BN;
  // Loading: each thread reads N_READS adjacent columns of one tile row.
  uint read_offset_y = (lid.x * N_READS) / BN;
  uint read_offset_x = (lid.x * N_READS) % BN;
  // Scanning: each lane is one tile row; each simdgroup owns n_scans columns.
  uint scan_offset_y = simd_lane_id;
  uint scan_offset_x = simd_group_id * n_scans;

  // Number of valid columns remaining from this threadgroup's first column.
  uint stride_limit = stride - global_index_x;
  in += offset + global_index_x + read_offset_x;
  out += offset + global_index_x + read_offset_x;
  threadgroup U* read_into =
      read_buffer + read_offset_y * BN_pad + read_offset_x;
  threadgroup U* read_from =
      read_buffer + scan_offset_y * BN_pad + scan_offset_x;

  for (uint j = 0; j < axis_size; j += BM) {
    // Calculate the indices for the current thread. index_y is the memory
    // row (flipped when reversing); check_index_y is the scan-order row used
    // for bounds checks.
    uint index_y = j + read_offset_y;
    uint check_index_y = index_y;
    if (reverse) {
      index_y = axis_size - 1 - index_y;
    }

    // Read the tile into threadgroup memory, padding out-of-range entries
    // with the identity
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
      for (int i = 0; i < N_READS; i++) {
        read_into[i] = in[index_y * stride + i];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
          read_into[i] = in[index_y * stride + i];
        } else {
          read_into[i] = Op::init;
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Read strided into registers
    for (int i = 0; i < n_scans; i++) {
      values[i] = read_from[i];
    }
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // Perform the scan down each column, then fold in the running prefix and
    // broadcast the new column total from the last lane
    for (int i = 0; i < n_scans; i++) {
      values[i] = op.simd_scan(values[i]);
      values[i] = op(values[i], prefix[i]);
      prefix[i] = simd_shuffle(values[i], simd_size - 1);
    }

    // Write the scanned tile back to threadgroup memory
    for (int i = 0; i < n_scans; i++) {
      read_from[i] = values[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write to device memory
    if (!inclusive) {
      // Exclusive: the first row gets the identity and each result moves one
      // row later in scan order.
      if (check_index_y == 0) {
        if ((read_offset_x + N_READS) < stride_limit) {
          for (int i = 0; i < N_READS; i++) {
            out[index_y * stride + i] = Op::init;
          }
        } else {
          for (int i = 0; i < N_READS; i++) {
            if ((read_offset_x + i) < stride_limit) {
              out[index_y * stride + i] = Op::init;
            }
          }
        }
      }
      if (reverse) {
        index_y -= 1;
        check_index_y += 1;
      } else {
        index_y += 1;
        check_index_y += 1;
      }
    }
    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
      for (int i = 0; i < N_READS; i++) {
        out[index_y * stride + i] = read_into[i];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
          out[index_y * stride + i] = read_into[i];
        }
      }
    }
  }
}

================================================ FILE: mlx/backend/metal/kernels/scan.metal ================================================
// Copyright ©
// 2023-2024 Apple Inc.

#include
#include

// clang-format off
using namespace metal;

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/scan.h"

// Explicit instantiation of contiguous_scan under host name
// "contig_scan_<name>".
#define instantiate_contiguous_scan( \
    name, itype, otype, op, inclusive, reverse, nreads) \
  template [[host_name("contig_scan_" #name)]] [[kernel]] void \
  contiguous_scan, nreads, inclusive, reverse>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& axis_size [[buffer(2)]], \
      uint3 gid [[threadgroup_position_in_grid]], \
      uint3 gsize [[threadgroups_per_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

// Explicit instantiation of strided_scan under host name
// "strided_scan_<name>".
#define instantiate_strided_scan( \
    name, itype, otype, op, inclusive, reverse, nreads) \
  template [[host_name("strided_scan_" #name)]] [[kernel]] void \
  strided_scan, nreads, inclusive, reverse>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& axis_size [[buffer(2)]], \
      const constant size_t& stride [[buffer(3)]], \
      const constant size_t& stride_blocks [[buffer(4)]], \
      uint3 gid [[threadgroup_position_in_grid]], \
      uint3 gsize [[threadgroups_per_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

// For one (name, input type, output type, op, N_READS) combination,
// instantiate all 8 variants: {contiguous, strided} x {inclusive, exclusive}
// x {forward, reverse}.
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
  instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
  instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
  instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
  instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \
  instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
  instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
  instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
  instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)

// The last argument (N_READS) is 2 for the 64-bit element types (uint64,
// int64, complex64) and 4 for all others.
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_bool__uint32, bool, uint32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4)
instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4)
instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4)
instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum, 2)
instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4)
instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4)
instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4)
instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4)
instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4)
instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd, 2)
instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4)
instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4)
instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4)
instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4)
instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4)
instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax, 2)
instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4)
instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4)
instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4)
instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4)
instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4)
instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
// logaddexp is only instantiated for floating-point and complex types.
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2)
// clang-format on

================================================ FILE: mlx/backend/metal/kernels/sdpa_vector.h ================================================
// Copyright © 2024 Apple Inc.

#include

using namespace metal;

// Function constants that specialize the kernels below when the pipeline is
// built:
//   has_mask         - mask stride buffers are bound
//   query_transposed - queries are indexed sequence-major (see q_offset)
//   do_causal        - skip keys after the query's aligned position
//   bool_mask        - a boolean keep-mask is bound
//   float_mask       - an additive float mask is bound
//   has_sinks        - a per-head sink logit is bound
//   blocks           - number of key blocks used by sdpa_vector_2pass_1
constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
constant bool has_sinks [[function_constant(25)]];
constant int blocks [[function_constant(26)]];

// Single-pass attention for a single query vector per threadgroup.
//
// Grid: tid.x = flattened (batch, query head), tid.y = query position.
// Work split:
//   - BN = 32 SIMD-groups walk the keys in an interleaved fashion (SIMD-group
//     g handles keys g, g + 32, ...);
//   - within a SIMD-group each of the BD = 32 lanes owns D / 32 query/key
//     dimensions and V / 32 value dimensions.
// Scores use an online softmax: a running max plus a running sum rescaled by
// exp(old_max - new_max). The 32 partial results are then merged through
// threadgroup memory.
// Grouped-query attention: query head q reads kv head q / gqa_factor.
// NOTE(review): presumably dispatched with 32 SIMD-groups of 32 threads per
// threadgroup; the BN-sized threadgroup arrays depend on it.
template
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor [[buffer(4)]],
    const constant int& N [[buffer(5)]],
    const constant size_t& k_head_stride [[buffer(6)]],
    const constant size_t& k_seq_stride [[buffer(7)]],
    const constant size_t& v_head_stride [[buffer(8)]],
    const constant size_t& v_seq_stride [[buffer(9)]],
    const constant float& scale [[buffer(10)]],
    const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
    const device T* fmask [[buffer(12), function_constant(float_mask)]],
    const constant int& mask_kv_seq_stride [[buffer(13), function_constant(has_mask)]],
    const constant int& mask_q_seq_stride [[buffer(14), function_constant(has_mask)]],
    const constant int& mask_head_stride [[buffer(15), function_constant(has_mask)]],
    const device T* sinks [[buffer(16), function_constant(has_sinks)]],
    const constant int& num_q_heads [[buffer(17), function_constant(has_sinks)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  int inner_k_stride = BN * int(k_seq_stride);
  int inner_v_stride = BN * int(v_seq_stride);

  typedef float U;

  thread U q[qk_per_thread];
  thread U k[qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int q_batch_head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int kv_head_idx = q_batch_head_idx / gqa_factor;
  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
  const int q_offset =
      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
      simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
      simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += q_batch_head_idx * mask_head_stride +
        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += q_batch_head_idx * mask_head_stride +
        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
  }

  // After the final transpose, SIMD-group g holds the value slice that lane g
  // owned, so the output pointer is offset by simd_gid rather than simd_lid.
  out += o_offset * V + simd_gid * v_per_thread;

  // Read the query (pre-multiplied by `scale`) and zero the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = Limits::finite_min;
  U sum_exp_score = 0;
  // A sink acts as one extra logit: it adds to the softmax denominator but
  // contributes no value vector. Seeding it in a single SIMD-group counts it
  // exactly once.
  if (has_sinks && simd_gid == 0) {
    max_score = static_cast(sinks[q_batch_head_idx % num_q_heads]);
    sum_exp_score = 1;
  }

  // For each key
  for (int i = simd_gid; i < N; i += BN) {
    bool use_key = true;
    if (do_causal) {
      // Queries are aligned to the end of the key sequence.
      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    } else if (float_mask) {
      // Skip keys whose additive mask is below the finite range (e.g. -inf).
      use_key = (fmask[0] >= Limits::finite_min);
    }
    if (use_key) {
      // Read the key
      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);
      if (float_mask) {
        score += static_cast(fmask[0]);
      }

      // Update the accumulators (online softmax rescaling)
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += inner_k_stride;
    values += inner_v_stride;
    if (bool_mask) {
      bmask += BN * mask_kv_seq_stride;
    }
    if (float_mask) {
      fmask += BN * mask_kv_seq_stride;
    }
  }

  // Each SIMD-group holds a partial result over its subset of keys, so the
  // partials must be combined.

  // First communicate the max and sum_exp, then rescale each SIMD-group's sum
  // to the global max
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Aggregate the outputs. Transpose through threadgroup memory: lane l of
  // SIMD-group g reads the value written by lane g of SIMD-group l, weights
  // it by SIMD-group l's rescale factor, and the SIMD-group sums over l.
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < v_per_thread; i++) {
      out[i] = static_cast(o[i]);
    }
  }
}

// First pass of the two-pass variant.
//
// Keys are split into `blocks` interleaved subsets (block b handles keys b,
// b + blocks, ...).
// Grid: tid.x = kv head, tid.y = batch, tid.z = block.
// Within a threadgroup:
//   - the y dimension indexes the gqa_factor query heads sharing the kv head;
//   - the z dimension indexes the query position.
// Each query writes its unnormalized partial output together with the
// per-block softmax sum and max; sdpa_vector_2pass_2 merges them.
template
[[kernel]] void sdpa_vector_2pass_1(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    device float* sums [[buffer(4)]],
    device float* maxs [[buffer(5)]],
    const constant int& N [[buffer(7)]],
    const constant size_t& k_head_stride [[buffer(8)]],
    const constant size_t& k_seq_stride [[buffer(9)]],
    const constant size_t& v_head_stride [[buffer(10)]],
    const constant size_t& v_seq_stride [[buffer(11)]],
    const constant float& scale [[buffer(12)]],
    const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
    const device T* fmask [[buffer(14), function_constant(float_mask)]],
    const constant int& mask_kv_seq_stride [[buffer(15), function_constant(has_mask)]],
    const constant int& mask_q_seq_stride [[buffer(16), function_constant(has_mask)]],
    const constant int& mask_head_stride [[buffer(17), function_constant(has_mask)]],
    const device T* sinks [[buffer(18), function_constant(has_sinks)]],
    uint3 tptg [[threads_per_threadgroup]],
    uint3 tidtg [[thread_position_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;

  typedef float U;

  thread U q[qk_per_thread];
  thread U o[v_per_thread] = {0};

  // Adjust positions
  const int kv_head_idx = tid.x;
  const int batch_idx = tid.y;
  const int block_idx = tid.z;
  const int gqa_factor = tptg.y;
  const int q_seq_len = tptg.z;
  const int q_seq_idx = tidtg.z;
  const int q_head_idx = gqa_factor * kv_head_idx + tidtg.y;
  const int num_kv_heads = tpg.x;
  const int num_q_heads = num_kv_heads * gqa_factor;
  const int q_batch_head_idx = (batch_idx * num_q_heads + q_head_idx);
  const int o_offset = q_batch_head_idx * q_seq_len + q_seq_idx;
  const int q_offset =
      query_transposed ? num_q_heads * q_seq_idx + q_batch_head_idx : o_offset;

  queries += q_offset * D + simd_lid * qk_per_thread;

  const int kv_batch_head_idx = batch_idx * num_kv_heads + kv_head_idx;
  keys += kv_batch_head_idx * k_head_stride + block_idx * k_seq_stride +
      simd_lid * qk_per_thread;
  values += kv_batch_head_idx * v_head_stride + block_idx * v_seq_stride +
      simd_lid * v_per_thread;
  // Partial outputs are laid out as [query][block][V].
  out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += q_batch_head_idx * mask_head_stride +
        block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += q_batch_head_idx * mask_head_stride +
        block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
  }
  sums += o_offset * blocks + block_idx;
  maxs += o_offset * blocks + block_idx;

  // Read the query (pre-multiplied by `scale`)
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast(scale) * queries[i];
  }

  U max_score = Limits::finite_min;
  U sum_exp_score = 0;
  // The sink logit is seeded only in block 0 so it is counted exactly once
  // when the blocks are merged.
  if (has_sinks && block_idx == 0) {
    max_score = static_cast(sinks[q_head_idx]);
    sum_exp_score = 1;
  }

  // For each key
  for (int i = block_idx; i < N; i += blocks) {
    bool use_key = true;
    if (do_causal) {
      use_key = i <= (N - q_seq_len + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    } else if (float_mask) {
      use_key = (fmask[0] >= Limits::finite_min);
    }
    if (use_key) {
      // Compute the i-th score (the inner `i` shadows the key index)
      U score = 0;
      for (int i = 0; i < qk_per_thread; i++) {
        score += q[i] * keys[i];
      }
      score = simd_sum(score);

      if (float_mask) {
        score += fmask[0];
      }

      // Update the accumulators (online softmax rescaling)
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int i = 0; i < v_per_thread; i++) {
        o[i] = o[i] * factor + exp_score * values[i];
      }
    }

    // Move the pointers to the next kv
    keys += blocks * int(k_seq_stride);
    values += blocks * int(v_seq_stride);
    if (bool_mask) {
      bmask += blocks * mask_kv_seq_stride;
    }
    if (float_mask) {
      fmask += blocks * mask_kv_seq_stride;
    }
  }

  // Write the sum and max and outputs
  if (simd_lid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = max_score;
  }

  for (int i = 0; i < v_per_thread; i++) {
    out[i] = static_cast(o[i]);
  }
}

// Second pass of the two-pass variant. It merges the `blocks` partial results
// written by sdpa_vector_2pass_1 for each query.
// Grid: tid.x = flattened (batch, head), tid.y = query position.
// The `blocks` buffer argument shadows the `blocks` function constant.
// NOTE(review): every loop runs blocks / BN times, so when `blocks` is not a
// multiple of 32 the leftover blocks are ignored. This presumably relies on
// the host always choosing a multiple of 32; confirm.
template
[[kernel]] void sdpa_vector_2pass_2(
    const device T* partials [[buffer(0)]],
    const device float* sums [[buffer(1)]],
    const device float* maxs [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& blocks [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;

  typedef float U;

  thread U o[elem_per_thread] = {0};
  threadgroup U outputs[BN * BD];

  // Adjust positions
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int q_offset = head_idx * tpg.y + q_seq_idx;
  partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
  sums += q_offset * blocks;
  maxs += q_offset * blocks;
  out += q_offset * D + simd_gid * elem_per_thread;

  // Set defaults
  U sum_exp_score = 0.0;
  U max_score = Limits::finite_min;

  // Reduce the max over all blocks
  for (int b = 0; b < blocks / BN; ++b) {
    max_score = max(max_score, maxs[simd_lid + BN * b]);
  }
  max_score = simd_max(max_score);

  // Reduce the softmax denominator: sum over blocks of
  // exp(max_b - max) * sum_b
  for (int b = 0; b < blocks / BN; ++b) {
    U factor = fast::exp(maxs[simd_lid + BN * b] - max_score);
    sum_exp_score += factor * sums[simd_lid + BN * b];
  }
  sum_exp_score = simd_sum(sum_exp_score);

  // Accumulate the rescaled partial outputs. SIMD-group g handles blocks
  // g, g + 32, ...
  for (int b = 0; b < blocks / BN; ++b) {
    U factor = fast::exp(maxs[simd_gid] - max_score);

    // Update the output accumulator
    for (int i = 0; i < elem_per_thread; i++) {
      o[i] += factor * static_cast(partials[i]);
    }
    maxs += BN;
    sums += BN;
    partials += BN * D;
  }

  // Use threadgroup memory to transpose and reduce the final block across
  // SIMD-groups, then normalize
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid]);
    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast(o[i]);
    }
  }
}

================================================ FILE: mlx/backend/metal/kernels/softmax.h ================================================
// Copyright © 2023-2024 Apple Inc.

template
inline T softmax_exp(T x) {
  // Softmax doesn't need a high-precision exponential: callers subtract the
  // row max first, so x is in (-inf, 0], and the result is afterwards divided
  // by sum(exp(x_i)).
  return fast::exp(x);
}

// Softmax over rows that fit in one pass of the threadgroup. Each thread loads
// N_READS contiguous elements of row `gid`. Out-of-range slots are set to
// Limits::min, so they contribute exp(-inf) = 0.
// The row max and the normalizer are each reduced in two levels: simd_max /
// simd_sum within a SIMD-group, then across SIMD-groups via threadgroup
// memory.
// NOTE(review): presumably the host only picks this kernel when
// N_READS * threads_per_threadgroup covers axis_size; confirm.
template
[[kernel]] void softmax_single_row(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  int lid = _lid;

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  AccT ld[N_READS];

  in += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      ld[i] = AccT(in[i]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      ld[i] =
          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min;
    }
  }
  // Initialize all 32 cross-SIMD-group slots so that slots of absent
  // SIMD-groups are neutral in the reductions below.
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits::min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Get the max
  AccT maxval = Limits::finite_min;
  for (int i = 0; i < N_READS; i++) {
    maxval = (maxval < ld[i]) ? ld[i] : maxval;
  }
  maxval = simd_max(maxval);
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    maxval = simd_max(local_max[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_max[0] = maxval;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = local_max[0];

  // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
  AccT normalizer = 0;
  for (int i = 0; i < N_READS; i++) {
    AccT exp_x = softmax_exp(ld[i] - maxval);
    ld[i] = exp_x;
    normalizer += exp_x;
  }
  normalizer = simd_sum(normalizer);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    normalizer = simd_sum(local_normalizer[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_normalizer[0] = normalizer;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = 1 / local_normalizer[0];

  // Normalize and write to the output
  out += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = T(ld[i] * normalizer);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        out[i] = T(ld[i] * normalizer);
      }
    }
  }
}

// Softmax for rows processed in chunks. The threadgroup walks row `gid` in
// ceildiv(axis_size, N_READS * lsize) steps of N_READS * lsize elements,
// accumulating the max and the normalizer in a single sweep.
template
[[kernel]] void softmax_looped(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  in += gid * size_t(axis_size);

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  // Get the max and the normalizer in one go
  AccT prevmax;
  AccT maxval = Limits::finite_min;
  AccT normalizer = 0;
  for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    AccT
vals[N_READS]; if (offset + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { vals[i] = AccT(in[offset + i]); } } else { for (int i = 0; i < N_READS; i++) { vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; for (int i = 0; i < N_READS; i++) { maxval = (maxval < vals[i]) ? vals[i] : maxval; } normalizer *= softmax_exp(prevmax - maxval); for (int i = 0; i < N_READS; i++) { normalizer += softmax_exp(vals[i] - maxval); } } // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS * // lsize) parts. We need to combine them. // 1. We start by finding the max across simd groups // 2. We then change the partial normalizers to account for a possible // change in max // 3. We sum all normalizers prevmax = maxval; maxval = simd_max(maxval); normalizer *= softmax_exp(prevmax - maxval); normalizer = simd_sum(normalizer); // Now the normalizer and max value is correct for each simdgroup. We write // them shared memory and combine them. 
prevmax = maxval; if (simd_lane_id == 0) { local_max[simd_group_id] = maxval; } threadgroup_barrier(mem_flags::mem_threadgroup); maxval = simd_max(local_max[simd_lane_id]); normalizer *= softmax_exp(prevmax - maxval); if (simd_lane_id == 0) { local_normalizer[simd_group_id] = normalizer; } threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = simd_sum(local_normalizer[simd_lane_id]); normalizer = 1 / normalizer; // Finally given the normalizer and max value we can directly write the // softmax output out += gid * size_t(axis_size); for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); r++) { int offset = r * lsize * N_READS + lid * N_READS; if (offset + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer); } } else { for (int i = 0; i < N_READS; i++) { if (offset + i < axis_size) { out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer); } } } } } ================================================ FILE: mlx/backend/metal/kernels/softmax.metal ================================================ // Copyright © 2023-2024 Apple Inc. 
// NOTE(review): the header names of the next two includes appear to have been
// stripped from this copy; reproduced as received. Confirm against upstream.
#include
#include

using namespace metal;

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/softmax.h"

// Instantiates the single-row ("block_softmax_*") and looped
// ("looped_softmax_*") softmax kernels for one input type.
#define instantiate_softmax(name, itype) \
  instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \
  instantiate_kernel("looped_softmax_" #name, softmax_looped, itype)

// Same as above, but with float as an extra template argument (presumably the
// accumulation type) for the "precise" half/bfloat16 variants.
#define instantiate_softmax_precise(name, itype) \
  instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \
  instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)

instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax_precise(float16, half)
instantiate_softmax_precise(bfloat16, bfloat16_t)
// clang-format on
================================================
FILE: mlx/backend/metal/kernels/sort.h
================================================
// Copyright © 2023-2024 Apple Inc.

// NOTE(review): angle-bracket template argument lists appear to have been
// stripped from this copy (e.g. after `template`, `Limits`, `Init`,
// `LessThan`, `ThreadSort`, `KernelMergeSort`, `is_floating_point_v`,
// `is_same_v`). The code is reproduced as received; restore the arguments from
// upstream before building.

#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")

using namespace metal;

// Based on GPU merge sort algorithm at
// https://github.com/NVIDIA/cccl/tree/main/cub/cub

///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
///////////////////////////////////////////////////////////////////////////////

// Swaps two thread-private values.
template METAL_FUNC void thread_swap(thread T& a, thread T& b) {
  T w = a;
  a = b;
  b = w;
}

// Padding value for unused sort slots: the type's maximum, so padding sorts
// after (or, being equal, stably after) every real value.
template struct Init {
  static constexpr constant T v = Limits::max;
};

// Specialization, presumably for floating-point T (its template arguments are
// not visible in this copy): pads with NaN, which LessThan orders last.
template struct Init>> {
  static constexpr constant T v = metal::numeric_limits::quiet_NaN();
};

// Ascending comparator. For the types selected by the `if constexpr` check
// (floating point and, presumably, bfloat16), NaN compares greater than every
// non-NaN value and not less than another NaN, so NaNs sort to the end.
template struct LessThan {
  static constexpr constant T init = Init::v;

  METAL_FUNC bool operator()(T a, T b) const {
    if constexpr (
        metal::is_floating_point_v || metal::is_same_v) {
      bool an = isnan(a);
      bool bn = isnan(b);
      if (an | bn) {
        return (!an) & bn;
      }
    }
    return a < b;
  }
};

// Sorts N_PER_THREAD register-resident values with an odd-even transposition
// sort: N_PER_THREAD rounds of compare-and-swap on alternating adjacent pairs.
// When ARG_SORT is set, idxs is permuted alongside vals.
template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short N_PER_THREAD,
    typename CompareOp>
struct ThreadSort {
  static METAL_FUNC void sort(
      thread ValT (&vals)[N_PER_THREAD],
      thread IdxT (&idxs)[N_PER_THREAD]) {
    CompareOp op;

    MLX_MTL_LOOP_UNROLL
    for (short i = 0; i < N_PER_THREAD; ++i) {
      MLX_MTL_LOOP_UNROLL
      for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
        if (op(vals[j + 1], vals[j])) {
          thread_swap(vals[j + 1], vals[j]);
          if (ARG_SORT) {
            thread_swap(idxs[j + 1], idxs[j]);
          }
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
// Threadgroup-level sort
///////////////////////////////////////////////////////////////////////////////

// Sorts BLOCK_THREADS * N_PER_THREAD values held in threadgroup memory. Each
// thread first sorts its own N_PER_THREAD slice; then rounds with
// merge_threads = 2, 4, ..., BLOCK_THREADS merge pairs of sorted runs, doubling
// the run length each round.
template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp>
struct BlockMergeSort {
  using thread_sort_t = ThreadSort;

  // Binary search over the merge of As[0:A_sz] and Bs[0:B_sz]. Returns how
  // many of the first sort_md merged elements come from As. On ties the
  // element from As is taken first, which keeps the merge stable.
  static METAL_FUNC int merge_partition(
      const threadgroup ValT* As,
      const threadgroup ValT* Bs,
      short A_sz,
      short B_sz,
      short sort_md) {
    CompareOp op;

    short A_st = max(0, sort_md - B_sz);
    short A_ed = min(sort_md, A_sz);

    while (A_st < A_ed) {
      short md = A_st + (A_ed - A_st) / 2;
      auto a = As[md];
      auto b = Bs[sort_md - 1 - md];

      if (op(b, a)) {
        A_ed = md;
      } else {
        A_st = md + 1;
      }
    }

    return A_ed;
  }

  // Writes the first N_PER_THREAD elements of merge(As[0:A_sz], Bs[0:B_sz])
  // into registers, preferring As on ties. Exhausted inputs read as
  // CompareOp::init. With ARG_SORT, the matching indices are gathered too.
  static METAL_FUNC void merge_step(
      const threadgroup ValT* As,
      const threadgroup ValT* Bs,
      const threadgroup IdxT* As_idx,
      const threadgroup IdxT* Bs_idx,
      short A_sz,
      short B_sz,
      thread ValT (&vals)[N_PER_THREAD],
      thread IdxT (&idxs)[N_PER_THREAD]) {
    CompareOp op;
    short a_idx = 0;
    short b_idx = 0;

    for (int i = 0; i < N_PER_THREAD; ++i) {
      auto a = (a_idx < A_sz) ? As[a_idx] : ValT(CompareOp::init);
      auto b = (b_idx < B_sz) ? Bs[b_idx] : ValT(CompareOp::init);
      bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));

      vals[i] = pred ? b : a;
      if (ARG_SORT) {
        if (pred) {
          idxs[i] = Bs_idx[b_idx];
        } else {
          idxs[i] = (a_idx < A_sz) ? As_idx[a_idx] : IdxT(0);
        }
      }

      b_idx += short(pred);
      a_idx += short(!pred);
    }
  }

  // Sorts tgp_vals (and tgp_idxs when ARG_SORT) in place. Thread lid.x owns
  // positions [lid.x * N_PER_THREAD, (lid.x + 1) * N_PER_THREAD).
  static METAL_FUNC void sort(
      threadgroup ValT* tgp_vals [[threadgroup(0)]],
      threadgroup IdxT* tgp_idxs [[threadgroup(1)]],
      int size_sorted_axis,
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Get thread location
    int idx = lid.x * N_PER_THREAD;

    // Load from shared memory
    thread ValT thread_vals[N_PER_THREAD];
    thread IdxT thread_idxs[N_PER_THREAD];
    for (int i = 0; i < N_PER_THREAD; ++i) {
      thread_vals[i] = tgp_vals[idx + i];
      if (ARG_SORT) {
        thread_idxs[i] = tgp_idxs[idx + i];
      }
    }

    // Per thread sort (slices entirely past the data hold only padding)
    if (idx < size_sorted_axis) {
      thread_sort_t::sort(thread_vals, thread_idxs);
    }

    // Do merges using threadgroup memory
    for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
         merge_threads *= 2) {
      // Update threadgroup memory
      threadgroup_barrier(mem_flags::mem_threadgroup);
      for (int i = 0; i < N_PER_THREAD; ++i) {
        tgp_vals[idx + i] = thread_vals[i];
        if (ARG_SORT) {
          tgp_idxs[idx + i] = thread_idxs[i];
        }
      }
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Find location in merge step
      int merge_group = lid.x / merge_threads;
      int merge_lane = lid.x % merge_threads;

      int sort_sz = N_PER_THREAD * merge_threads;
      int sort_st = N_PER_THREAD * merge_threads * merge_group;

      // As = tgp_vals[A_st:A_ed] is sorted
      // Bs = tgp_vals[B_st:B_ed] is sorted
      int A_st = sort_st;
      int A_ed = sort_st + sort_sz / 2;
      int B_st = sort_st + sort_sz / 2;
      int B_ed = sort_st + sort_sz;

      const threadgroup ValT* As = tgp_vals + A_st;
      const threadgroup ValT* Bs = tgp_vals + B_st;
      int A_sz = A_ed - A_st;
      int B_sz = B_ed - B_st;

      // Find a partition of merge elements
      //    Ci = merge(As[partition:], Bs[sort_md - partition:])
      //        of size N_PER_THREAD for each merge lane i
      //    C = [Ci] is sorted
      int sort_md = N_PER_THREAD * merge_lane;
      int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);

      As += partition;
      Bs += sort_md - partition;

      A_sz -= partition;
      B_sz -= sort_md - partition;

      const threadgroup IdxT* As_idx =
          ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
      const threadgroup IdxT* Bs_idx =
          ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;

      // Merge starting at the partition and store results in thread registers
      merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
    }

    // Write out to shared memory
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int i = 0; i < N_PER_THREAD; ++i) {
      tgp_vals[idx + i] = thread_vals[i];
      if (ARG_SORT) {
        tgp_idxs[idx + i] = thread_idxs[i];
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
// Kernel sort
///////////////////////////////////////////////////////////////////////////////

// Sorts one segment (selected by tid.y) of at most N_PER_BLOCK elements in a
// single threadgroup. Writes the sorted values, or the argsort indices when
// ARG_SORT is set.
template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan>
struct KernelMergeSort {
  using ValT = T;
  using IdxT = uint;
  using block_merge_sort_t = BlockMergeSort<
      ValT,
      IdxT,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD,
      CompareOp>;

  MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;

  static METAL_FUNC void block_sort(
      const device T* inp,
      device U* out,
      const constant int& size_sorted_axis,
      const constant int& in_stride_sorted_axis,
      const constant int& out_stride_sorted_axis,
      const constant int& in_stride_segment_axis,
      const constant int& out_stride_segment_axis,
      threadgroup ValT* tgp_vals,
      threadgroup IdxT* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // tid.y tells us the segment index
    inp += tid.y * in_stride_segment_axis;
    out += tid.y * out_stride_segment_axis;

    // Copy into threadgroup memory, padding past the end with CompareOp::init
    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
                                         : ValT(CompareOp::init);
      if (ARG_SORT) {
        tgp_idxs[i] = i;
      }
    }

    // Sort elements within the block
    threadgroup_barrier(mem_flags::mem_threadgroup);

    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write output
    for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
      if (ARG_SORT) {
        out[i * out_stride_sorted_axis] = tgp_idxs[i];
      } else {
        out[i * out_stride_sorted_axis] = tgp_vals[i];
      }
    }
  }
};

// Single-block sort for segments laid out with a fixed segment stride.
// Index storage in threadgroup memory is only declared when ARG_SORT is set.
template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
    const device T* inp [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& in_stride_sorted_axis [[buffer(3)]],
    const constant int& out_stride_sorted_axis [[buffer(4)]],
    const constant int& in_stride_segment_axis [[buffer(5)]],
    const constant int& out_stride_segment_axis [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
      KernelMergeSort;
  using ValT = typename sort_kernel::ValT;
  using IdxT = typename sort_kernel::IdxT;

  if (ARG_SORT) {
    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
    threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        in_stride_segment_axis,
        out_stride_segment_axis,
        tgp_vals,
        tgp_idxs,
        tid,
        lid);
  } else {
    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        in_stride_segment_axis,
        out_stride_segment_axis,
        tgp_vals,
        nullptr,
        tid,
        lid);
  }
}

// Zero segment stride, passed to block_sort by block_sort_nc after it has
// already offset the pointers to the segment.
constant constexpr const int zero_helper = 0;

// Single-block sort for non-contiguous segments: segment offsets come from
// elem_to_loc over (nc_shape, in/out_nc_strides, nc_dim).
template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
    const device T* inp [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& in_stride_sorted_axis [[buffer(3)]],
    const constant int& out_stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const constant int* nc_shape [[buffer(6)]],
    const constant int64_t* in_nc_strides [[buffer(7)]],
    const constant int64_t* out_nc_strides [[buffer(8)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
      KernelMergeSort;
  using ValT = typename sort_kernel::ValT;
  using IdxT = typename sort_kernel::IdxT;

  auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
  auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
  inp += in_block_idx;
  out += out_block_idx;

  if (ARG_SORT) {
    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
    threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        zero_helper,
        zero_helper,
        tgp_vals,
        tgp_idxs,
        tid,
        lid);
  } else {
    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        zero_helper,
        zero_helper,
        tgp_vals,
        nullptr,
        tid,
        lid);
  }
}

// Building blocks for sorting rows longer than one block. block_sort sorts one
// N_PER_BLOCK block of a row (always tracking indices). merge_partition is the
// device-memory counterpart of BlockMergeSort::merge_partition and is used
// between merge rounds.
template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan>
struct KernelMultiBlockMergeSort {
  using block_merge_sort_t = BlockMergeSort<
      ValT,
      IdxT,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD,
      CompareOp>;

  MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;

  static METAL_FUNC void block_sort(
      const device ValT* inp,
      device ValT* out_vals,
      device IdxT* out_idxs,
      const constant int& size_sorted_axis,
      const constant int& stride_sorted_axis,
      threadgroup ValT* tgp_vals,
      threadgroup IdxT* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // tid.x selects which N_PER_BLOCK-sized block of the row this threadgroup
    // sorts. The caller has already offset the pointers to the row (tid.y).
    int base_idx = tid.x * N_PER_BLOCK;

    // Copy into threadgroup memory, padding past the end with CompareOp::init
    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      int idx = base_idx + i;
      tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
                                           : ValT(CompareOp::init);
      tgp_idxs[i] = idx;
    }

    // Sort elements within the block
    threadgroup_barrier(mem_flags::mem_threadgroup);

    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write output
    for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      int idx = base_idx + i;
      if (idx < size_sorted_axis) {
        out_vals[idx] = tgp_vals[i];
        out_idxs[idx] = tgp_idxs[i];
      }
    }
  }

  // Returns how many of the first sort_md elements of merge(As, Bs) come from
  // As, taking As first on ties (binary search over device memory).
  static METAL_FUNC int merge_partition(
      const device ValT* As,
      const device ValT* Bs,
      int A_sz,
      int B_sz,
      int sort_md) {
    CompareOp op;

    int A_st = max(0, sort_md - B_sz);
    int A_ed = min(sort_md, A_sz);

    while (A_st < A_ed) {
      int md = A_st + (A_ed - A_st) / 2;
      auto a = As[md];
      auto b = Bs[sort_md - 1 - md];

      if (op(b, a)) {
        A_ed = md;
      } else {
        A_st = md + 1;
      }
    }

    return A_ed;
  }
};

// Multi-block sort, step 1: sorts each N_PER_BLOCK block of row tid.y
// independently. Values and indices are written contiguously (row stride
// size_sorted_axis) into out_vals / out_idxs.
template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
    const device ValT* inp [[buffer(0)]],
    device ValT* out_vals [[buffer(1)]],
    device IdxT* out_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const constant int* nc_shape [[buffer(6)]],
    const constant int64_t* nc_strides [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      ValT,
      IdxT,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD>;

  auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
  inp += block_idx;
  out_vals += tid.y * size_sorted_axis;
  out_idxs += tid.y * size_sorted_axis;

  threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];

  sort_kernel::block_sort(
      inp,
      out_vals,
      out_idxs,
      size_sorted_axis,
      stride_sorted_axis,
      tgp_vals,
      tgp_idxs,
      tid,
      lid);
}

// Multi-block sort, step 2 (per merge round): for every block boundary i in
// [0, n_blocks] of row tid.y, records the absolute index in the row where
// block i's slice of the left run A starts. Runs of merge_tiles / 2 blocks are
// merged pairwise. mb_block_merge reads entries i and i + 1.
// tid.x is unused, so presumably one threadgroup is launched per row. Threads
// stride over the n_blocks + 1 boundaries, so any threadgroup width works.
template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel]] void mb_block_partition(
    device IdxT* block_partitions [[buffer(0)]],
    const device ValT* dev_vals [[buffer(1)]],
    const device IdxT* dev_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& merge_tiles [[buffer(4)]],
    const constant int& n_blocks [[buffer(5)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_dims [[threads_per_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      ValT,
      IdxT,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD>;

  // Each row owns n_blocks + 1 partition entries: the loop below writes
  // indices 0..n_blocks, and mb_block_merge reads them with a row stride of
  // num_tiles + 1. Striding rows by the threadgroup width instead would make
  // consecutive rows overlap whenever fewer than n_blocks + 1 threads are
  // launched, which the strided loop explicitly supports.
  block_partitions += tid.y * (n_blocks + 1);
  dev_vals += tid.y * size_sorted_axis;
  dev_idxs += tid.y * size_sorted_axis;

  for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
    // Find location in merge step
    int merge_group = i / merge_tiles;
    int merge_lane = i % merge_tiles;

    int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
    int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;

    int A_st = min(size_sorted_axis, sort_st);
    int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
    int B_st = A_ed;
    int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);

    int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
    int partition = sort_kernel::merge_partition(
        dev_vals + A_st,
        dev_vals + B_st,
        A_ed - A_st,
        B_ed - B_st,
        partition_at);

    block_partitions[i] = A_st + partition;
  }
}

// Multi-block sort, step 3 (per merge round): threadgroup (tid.x, tid.y)
// produces output block tid.x of row tid.y.
// - Its slice of the left run A is [block_partitions[tid.x],
//   block_partitions[tid.x + 1]).
// - Its slice of the right run B follows from the block consuming exactly
//   N_PER_BLOCK elements of A and B combined. The last block of each merge
//   group takes the remainder of both runs.
// - The two slices are merged in threadgroup memory with BlockMergeSort's
//   helpers and written to the *_out buffers.
template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
    const device IdxT* block_partitions [[buffer(0)]],
    const device ValT* dev_vals_in [[buffer(1)]],
    const device IdxT* dev_idxs_in [[buffer(2)]],
    device ValT* dev_vals_out [[buffer(3)]],
    device IdxT* dev_idxs_out [[buffer(4)]],
    const constant int& size_sorted_axis [[buffer(5)]],
    const constant int& merge_tiles [[buffer(6)]],
    const constant int& num_tiles [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      ValT,
      IdxT,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD,
      CompareOp>;

  using block_sort_t = typename sort_kernel::block_merge_sort_t;

  block_partitions += tid.y * (num_tiles + 1);
  dev_vals_in += tid.y * size_sorted_axis;
  dev_idxs_in += tid.y * size_sorted_axis;
  dev_vals_out += tid.y * size_sorted_axis;
  dev_idxs_out += tid.y * size_sorted_axis;

  int block_idx = tid.x;
  int merge_group = block_idx / merge_tiles;
  int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
  int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
  int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;

  int A_st = block_partitions[block_idx + 0];
  int A_ed = block_partitions[block_idx + 1];
  int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
  int B_ed = min(
      size_sorted_axis,
      2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);

  if ((block_idx % merge_tiles) == merge_tiles - 1) {
    A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
    B_ed = min(size_sorted_axis, sort_st + sort_sz);
  }

  int A_sz = A_ed - A_st;
  int B_sz = B_ed - B_st;

  // Load from global memory
  thread ValT thread_vals[N_PER_THREAD];
  thread IdxT thread_idxs[N_PER_THREAD];
  for (int i = 0; i < N_PER_THREAD; i++) {
    int idx = BLOCK_THREADS * i + lid.x;
    if (idx < (A_sz + B_sz)) {
      thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
                                    : dev_vals_in[B_st + idx - A_sz];
      thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
                                    : dev_idxs_in[B_st + idx - A_sz];
    } else {
      thread_vals[i] = CompareOp::init;
      thread_idxs[i] = 0;
    }
  }

  // Write to shared memory: A slice at [0, A_sz), B slice at [A_sz, A_sz + B_sz)
  threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int i = 0; i < N_PER_THREAD; i++) {
    int idx = BLOCK_THREADS * i + lid.x;
    tgp_vals[idx] = thread_vals[i];
    tgp_idxs[idx] = thread_idxs[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Merge: thread lid.x produces merged outputs starting at
  // N_PER_THREAD * lid.x
  int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));

  int A_st_local = block_sort_t::merge_partition(
      tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
  int A_ed_local = A_sz;

  int B_st_local = sort_md_local - A_st_local;
  int B_ed_local = B_sz;

  int A_sz_local = A_ed_local - A_st_local;
  int B_sz_local = B_ed_local - B_st_local;

  // Do merge
  block_sort_t::merge_step(
      tgp_vals + A_st_local,
      tgp_vals + A_ed_local + B_st_local,
      tgp_idxs + A_st_local,
      tgp_idxs + A_ed_local + B_st_local,
      A_sz_local,
      B_sz_local,
      thread_vals,
      thread_idxs);

  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int i = 0; i < N_PER_THREAD; ++i) {
    int idx = lid.x * N_PER_THREAD;
    tgp_vals[idx + i] = thread_vals[i];
    tgp_idxs[idx + i] = thread_idxs[i];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Write output
  int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
  for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
    int idx = base_idx + i;
    if (idx < size_sorted_axis) {
      dev_vals_out[idx] = tgp_vals[i];
      dev_idxs_out[idx] = tgp_idxs[i];
    }
  }
}

================================================
FILE: mlx/backend/metal/kernels/sort.metal
================================================
// Copyright © 2023-2024 Apple Inc.
// NOTE(review): the header name of this include appears to have been stripped
// from this copy; confirm against upstream.
#include

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sort.h"

// Instantiates the single-block sort kernels for one (input, output) type
// pair: block_sort under a "c" name prefix and block_sort_nc under an "nc"
// prefix. The generated names encode name, types, bn and tn.
#define instantiate_block_sort( \
    name, itname, itype, otname, otype, arg_sort, bn, tn) \
  instantiate_kernel("c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
                     block_sort, itype, otype, arg_sort, bn, tn) \
  instantiate_kernel("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
                     block_sort_nc, itype, otype, arg_sort, bn, tn)

// Argsort variant: outputs uint32 indices.
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
  instantiate_block_sort( \
      arg_block_sort, itname, itype, uint32, uint32_t, true, bn, tn)

// Value-sort variant: outputs the input type.
#define instantiate_block_sort_base(itname, itype, bn, tn) \
  instantiate_block_sort( \
      _block_sort, itname, itype, itname, itype, false, bn, tn)

// Both variants with tn = 4 values per thread.
#define instantiate_block_sort_tn(itname, itype, bn) \
  instantiate_block_sort_base(itname, itype, bn, 4) \
  instantiate_arg_block_sort_base(itname, itype, bn, 4)

// Threadgroup sizes 32..512.
#define instantiate_block_sort_bn(itname, itype) \
  instantiate_block_sort_tn(itname, itype, 32) \
  instantiate_block_sort_tn(itname, itype, 64) \
  instantiate_block_sort_tn(itname, itype, 128) \
  instantiate_block_sort_tn(itname, itype, 256) \
  instantiate_block_sort_tn(itname, itype, 512)

instantiate_block_sort_bn(uint8, uint8_t)
instantiate_block_sort_bn(uint16, uint16_t)
instantiate_block_sort_bn(uint32, uint32_t)
instantiate_block_sort_bn(int8, int8_t)
instantiate_block_sort_bn(int16, int16_t)
instantiate_block_sort_bn(int32, int32_t)
instantiate_block_sort_bn(float16, half)
instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t)

// 64-bit types stop at bn = 256, presumably to keep the threadgroup arrays
// (N_PER_BLOCK 8-byte values) within the threadgroup-memory limit.
#define instantiate_block_sort_long(itname, itype) \
  instantiate_block_sort_tn(itname, itype, 32) \
  instantiate_block_sort_tn(itname, itype, 64) \
  instantiate_block_sort_tn(itname, itype, 128) \
  instantiate_block_sort_tn(itname, itype, 256)

instantiate_block_sort_long(uint64, uint64_t)
instantiate_block_sort_long(int64, int64_t)

// Instantiates the three kernels of the multi-block sort for one value/index
// type pair: per-block sort ("sort_mbsort_*"), merge partition
// ("partition_mbsort_*") and merge ("merge_mbsort_*").
#define instantiate_multi_block_sort( \
    vtname, vtype, itname, itype, arg_sort, bn, tn) \
  instantiate_kernel("sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
                     mb_block_sort, vtype, itype, arg_sort, bn, tn) \
  instantiate_kernel("partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
                     mb_block_partition, vtype, itype, arg_sort, bn, tn) \
  instantiate_kernel("merge_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
                     mb_block_merge, vtype, itype, arg_sort, bn, tn)

// Multi-block sort with uint32 indices, bn = 512, tn = 4.
#define instantiate_multi_block_sort_base(vtname, vtype) \
  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 4)

instantiate_multi_block_sort_base(uint8, uint8_t)
instantiate_multi_block_sort_base(uint16, uint16_t)
instantiate_multi_block_sort_base(uint32, uint32_t)
instantiate_multi_block_sort_base(int8, int8_t)
instantiate_multi_block_sort_base(int16, int16_t)
instantiate_multi_block_sort_base(int32, int32_t)
instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)

// 64-bit value types use bn = 256.
#define instantiate_multi_block_sort_long(vtname, vtype) \
  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 4)

instantiate_multi_block_sort_long(uint64, uint64_t)
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on
================================================
FILE: mlx/backend/metal/kernels/steel/attn/attn.h
================================================
// Copyright © 2024 Apple Inc.
#pragma once

#include "mlx/backend/metal/kernels/steel/attn/loader.h"
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////

// NOTE(review): angle-bracket template argument lists appear to have been
// stripped from this copy. Examples: the parameter lists after `template`, the
// arguments of `AccumHelper`, `TransformNone` and `LoopAlignment`, and the
// explicit arguments on the `gemm_loop` calls in `run`. The text is reproduced
// as received; restore them from upstream before building.

namespace mlx {
namespace steel {

// Empty tag type. Presumably it carries the M/N/K alignment flags as template
// arguments, so that gemm_loop can deduce them from its last parameter.
template struct LoopAlignment {};

// Tiled GEMM computing one BM x BN tile of D from A and B (each optionally
// transposed) per threadgroup, using WM x WN simdgroups (tgp_size threads).
// BK-deep tiles of A and B are staged in threadgroup memory (As, Bs) by
// BlockLoader and multiplied by BlockMMA. Each staged row is padded by
// 16 / sizeof(T) elements (16 bytes), presumably to reduce threadgroup-memory
// bank conflicts.
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper::accum_type,
    typename Epilogue = TransformNone>
struct GEMMKernel {
  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
  STEEL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  STEEL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  STEEL_CONST short tgp_size = WM * WN * 32;

  using loader_a_t = BlockLoader<
      T,
      transpose_a ? BK : BM,
      transpose_a ? BM : BK,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      !transpose_a,
      tgp_size>;
  using loader_b_t = BlockLoader<
      T,
      transpose_b ? BN : BK,
      transpose_b ? BK : BN,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      transpose_b,
      tgp_size>;
  using mma_t = BlockMMA<
      T,
      U,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      AccumType,
      Epilogue>;

  /* K loop for an output tile whose valid extent is tgp_bm x tgp_bn.
     Runs gemm_k_iterations full BK steps. A uses the bounds-checked
     load_safe when M_aligned is false, and B does the same when N_aligned is
     false. When K_aligned_ is false, one extra step loads and multiplies the
     remaining lbk elements along K. */
  template
  static METAL_FUNC void gemm_loop(
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      const int gemm_k_iterations,
      thread loader_a_t& loader_a,
      thread loader_b_t& loader_b,
      thread mma_t& mma_op,
      thread const short& tgp_bm,
      thread const short& tgp_bn,
      thread const short& lbk,
      LoopAlignment l = {}) {
    // Appease the compiler
    (void)l;

    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);

    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(tile_dims_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(tile_dims_B);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    if (!K_aligned_) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }
  }

  /* Main kernel function: computes the output tile selected by tid (after the
     swizzle below) and stores it to D. Threadgroups that fall outside the
     tiles_m x tiles_n tile grid return immediately. */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device U* D [[buffer(2)]],
      const constant GEMMParams* params [[buffer(3)]],
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    // Swizzle the launch grid: the low swizzle_log bits of tid.x pick a row
    // within a group of 2^swizzle_log tile rows (presumably for cache
    // locality).
    const int tid_y = ((tid.y) << params->swizzle_log) +
        ((tid.x) & ((1 << params->swizzle_log) - 1));
    const int tid_x = (tid.x) >> params->swizzle_log;

    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
      return;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Find block in A, B, C
    const int c_row = tid_y * BM;
    const int c_col = tid_x * BN;
    // size_t offsets so that row * leading-dimension products do not overflow
    // int
    const size_t c_row_long = size_t(c_row);
    const size_t c_col_long = size_t(c_col);

    A += transpose_a ? c_row_long : c_row_long * params->lda;
    B += transpose_b ? c_col_long * params->ldb : c_col_long;
    D += c_row_long * params->ldd + c_col_long;

    // Prepare threadgroup loading operations
    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    int gemm_k_iterations = params->gemm_k_iterations_aligned;

    ///////////////////////////////////////////////////////////////////////////////
    // MNK aligned loop
    if (MN_aligned) {
      for (int k = 0; k < gemm_k_iterations; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      threadgroup_barrier(mem_flags::mem_none);

      // Loop tail: the remaining lbk elements of K.
      // NOTE(review): unlike the tail in gemm_loop, this path overwrites As/Bs
      // after only a mem_none barrier. Confirm that this is enough to order the
      // loads after the last mma's threadgroup-memory reads.
      if (!K_aligned) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }

      // Store results to device memory
      mma_op.store_result(D, params->ldd);
      return;

    }
    ///////////////////////////////////////////////////////////////////////////////
    // MN unaligned loop
    else { // Loop over K - unaligned case
      // Valid extent of this tile (smaller than BM/BN on the matrix edges).
      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);
      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;

      // NOTE(review): the four gemm_loop calls below look identical here
      // because their explicit template arguments were stripped. Presumably
      // they differ in the M/N alignment flags for each case; confirm against
      // upstream.
      if (tgp_bm == BM && tgp_bn == BN) {
        gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result(D, params->ldd);
        return;

      } else if (tgp_bn == BN) {
        gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else if (tgp_bm == BM) {
        gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else {
        gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;
      }
    }
  }
};

} // namespace steel
} // namespace mlx

================================================
FILE: mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
================================================
// Copyright © 2024-25 Apple Inc.
#include "mlx/backend/metal/kernels/steel/attn/attn.h"

using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

// Specialization switches supplied at pipeline creation time.
// align_Q / align_K: when false, the last Q / KV block is partial and the
// bounds-checked load / masking paths are taken for it.
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];

// has_mask: an explicit additive (float) or boolean mask is bound at buffer 6.
// do_causal: apply a causal mask computed from row/column positions.
// has_sinks: a per-head sink logit is bound at buffer 7.
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];

// Binary functors handed to the tile row_reduce / row_bin_op helpers.
// Each exposes a static apply(x, y).
struct MaxOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return metal::max(x, y);
  }
};

struct SumOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x + y;
  }
};

struct MulOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x * y;
  }
};

struct SubOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x - y;
  }
};

// Computes 2^(x - y). Scores are pre-multiplied by log2(e) in the kernel
// below, so this matches exp(x - y) in natural-log units.
struct ExpSubOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return fast::exp2(x - y);
  }
};

struct DivOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x / y;
  }
};

// Fused attention forward kernel: O = softmax(scale * Q K^T [+ masks]) V.
//
// Each threadgroup computes one block of BQ query rows for one
// (batch = tid.z, query head = tid.y) pair. It walks the key/value sequence in
// blocks of BK, keeping a running row max and row sum (online softmax) so the
// full score matrix is never materialised. BD is the head dimension.
// WM * WN simdgroups cooperate; each owns kFragSize * TQ query rows.
// clang-format off
template <
    typename T,
    int BQ,
    int BK,
    int BD,
    int WM,
    int WN,
    typename MaskType = float,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
    const device T* Q [[buffer(0)]],
    const device T* K [[buffer(1)]],
    const device T* V [[buffer(2)]],
    device T* O [[buffer(3)]],
    const constant AttnParams* params [[buffer(4)]],
    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
    const device T* sinks [[buffer(7), function_constant(has_sinks)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on

  // Pacifying compiler
  (void)lid;

  // Move to correct block
  ulong3 tidl{tid.x, tid.y, tid.z};

  Q += tidl.z * params->Q_strides[0] + // Batch
      tidl.y * params->Q_strides[1] + // Head
      tidl.x * BQ * params->Q_strides[2]; // Sequence

  // Grouped-query attention: gqa_factor consecutive query heads share one
  // K/V head.
  ulong kv_head_idx = int(tid.y) / params->gqa_factor;
  K += tidl.z * params->K_strides[0] + // Batch
      kv_head_idx * params->K_strides[1]; // Head

  V += tidl.z * params->V_strides[0] + // Batch
      kv_head_idx * params->V_strides[1]; // Head

  O += tidl.z * params->O_strides[0] + // Batch
      tidl.y * params->O_strides[1] + // Head
      tidl.x * BQ * params->O_strides[2]; // Sequence

  if (has_mask) {
    mask += tidl.z * mask_params->M_strides[0] + // Batch
        tidl.y * mask_params->M_strides[1]; // Head
  }

  // Prepare threadgroup memory.
  // Each threadgroup row is padded by 16 bytes (16 / sizeof(T) elements),
  // presumably to reduce threadgroup-memory bank conflicts.
  constexpr short padQ = 16 / sizeof(T);
  constexpr short padK = 16 / sizeof(T);
  constexpr short padV = 16 / sizeof(T);

  constexpr short LDQ_tgp = BD + padQ;
  constexpr short LDK_tgp = BK + padK;
  constexpr short LDV_tgp = BD + padV;

  // K (stored transposed) and V (stored row-major) reuse the same buffer, so
  // it is sized for the larger of the two padded layouts.
  constexpr short tgp_mem_0 = (BK + padK) * (BD);
  constexpr short tgp_mem_1 = BK * (BD + padV);
  constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;

  threadgroup T Q_smem[BQ * (BD + padQ)];
  threadgroup T KV_smem[tgp_mem_s];

  // Ks and Vs alias KV_smem. The threadgroup barriers in the KV loop keep the
  // K reads and V writes from overlapping.
  threadgroup T* Qs = Q_smem;
  threadgroup T* Ks = KV_smem;
  threadgroup T* Vs = KV_smem;

  // Prepare block loaders
  using QBlockLoader = BlockLoaderT<
      /* typename T = */ T,
      /* short BROWS = */ BQ,
      /* short BCOLS = */ BD,
      /* short kDstStrRow = */ LDQ_tgp,
      /* short kDstStrCol = */ 1,
      /* short reduction_dim = */ 1,
      /* short tgp_size = */ WM * WN * 32>;

  // K is loaded in transposed (destination row stride 1, column stride LDK_tgp)
  using KBlockLoader = BlockLoaderT<
      /* typename T = */ T,
      /* short BROWS = */ BK,
      /* short BCOLS = */ BD,
      /* short kDstStrRow = */ 1,
      /* short kDstStrCol = */ LDK_tgp,
      /* short reduction_dim = */ 0,
      /* short tgp_size = */ WM * WN * 32>;

  using VBlockLoader = BlockLoaderT<
      /* typename T = */ T,
      /* short BROWS = */ BK,
      /* short BCOLS = */ BD,
      /* short kDstStrRow = */ LDV_tgp,
      /* short kDstStrCol = */ 1,
      /* short reduction_dim = */ 0,
      /* short tgp_size = */ WM * WN * 32>;

  QBlockLoader loader_q(
      Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
  KBlockLoader loader_k(
      K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
  VBlockLoader loader_v(
      V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);

  // Fold log2(e) into the softmax scale so that exp2 (ExpSubOp) can be used.
  const AccumType scale = params->scale * M_LOG2E_F;

  // Prepare MMA tiles
  constexpr short kFragSize = 8; // MMAFrag size
  using MMAFrag_acc_t = BaseMMAFrag;

  constexpr int kNWarps = WM * WN;
  static_assert(
      BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
      "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");

  // Q seq frags per warp
  constexpr int TQ = BQ / (kNWarps * kFragSize);
  // KV sequence frags (all warps load the same frags)
  constexpr int TK = BK / kFragSize;
  // HeadDim frags (all warps load the same frags)
  constexpr int TD = BD / kFragSize;

  static_assert(TQ == 1, "Check TQ");

  MMATile Qtile;
  MMATile Ktile;
  MMATile Stile;
  MMATile Vtile;
  MMATile Otile;

  Otile.clear();

  // Prepare mma tile offsets.
  // (sn, sm) is this lane's (column, row) inside an 8x8 fragment, and tm is
  // the first query row owned by this simdgroup.
  const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
  const short sm = simd_coord.y;
  const short sn = simd_coord.x;
  const short tm = kFragSize * TQ * simd_group_id;

  const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
  const short Ks_offset = sm * LDK_tgp + sn;
  const short Vs_offset = sm * LDV_tgp + sn;

  constexpr short Qs_tile_stride = kFragSize;
  constexpr short Ks_tile_stride = kFragSize * LDK_tgp;

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Load Q blocks (bounds-checked only for the trailing partial block)
  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
    loader_q.load_safe(short2(BD, params->qL_rem));
  } else {
    loader_q.load_unsafe();
  }

  // Init row reduction variables
  constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;

  AccumType max_score[kRowsPT];
  AccumType sum_score[kRowsPT] = {0};

  // Init to -Inf (the most negative finite value stands in for -inf)
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    max_score[i] = Limits::finite_min;
  }

  // A sink adds one extra logit per head. Seeding the running max with it
  // makes its contribution to the running sum exp2(0) = 1.
  if (has_sinks) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]);
      sum_score[i] = 1;
    }
  }

  // kb_lim: number of KV blocks to visit. kb_min_causal: first KV block that
  // may contain positions after some query row in this block; earlier blocks
  // need no causal masking. Query positions are shifted by qL_off (presumably
  // kL - qL, aligning queries to the end of the keys; confirm on the host side).
  int kb_lim = params->NK;
  int kb_min_causal = params->NK;

  if (do_causal) {
    int q_max = (tid.x + 1) * BQ + params->qL_off;
    kb_lim = (q_max + BK - 1) / BK;
    kb_lim = min(params->NK, kb_lim);

    int q_min = tid.x * BQ + params->qL_off;
    q_min = max(0, q_min);
    kb_min_causal = (q_min / BK);
  }

  // Loop over KV seq length
  for (int kb = 0; kb < kb_lim; kb++) {
    // Load K block (the scale is applied to S after the matmul)
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (!align_K && kb == (params->NK_aligned)) {
      loader_k.load_safe(short2(BD, params->kL_rem));
    } else {
      loader_k.load_unsafe();
    }

    // Do S = Q @ K.T
    Stile.clear();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    STEEL_PRAGMA_UNROLL
    for (short dd = 0; dd < TD; dd++) {
      simdgroup_barrier(mem_flags::mem_none);

      Qtile.template load(
          &Qs[Qs_offset + dd * Qs_tile_stride]);
      Ktile.template load(
          &Ks[Ks_offset + dd * Ks_tile_stride]);

      simdgroup_barrier(mem_flags::mem_none);

      tile_matmad(Stile, Qtile, Ktile, Stile);
    }

    // Apply scale in float32
    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
      Stile.elems()[ii] *= scale;
    }

    // Mask out length sequence: columns past kL_rem in the last partial block
    if (!align_K && kb == (params->NK_aligned)) {
      using stile_t = decltype(Stile);
      using selem_t = typename stile_t::elem_type;
      constexpr auto neg_inf = Limits::finite_min;

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < stile_t::kTileRows; i++) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < stile_t::kTileCols; j++) {
          short col_pos = sn + (j * stile_t::kFragCols);
          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
            if ((col_pos + jj) >= params->kL_rem) {
              Stile.frag_at(i, j)[jj] = neg_inf;
            }
          }
        }
      }
    }

    // Mask out if causal: a key column is hidden when it lies after the
    // (qL_off-shifted) query row.
    if (do_causal && kb >= kb_min_causal) {
      using stile_t = decltype(Stile);
      using selem_t = typename stile_t::elem_type;
      constexpr auto neg_inf = Limits::finite_min;

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < stile_t::kTileRows; i++) {
        const int row_pos =
            tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < stile_t::kTileCols; j++) {
          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
            if (row_pos < (col_pos + jj)) {
              Stile.frag_at(i, j)[jj] = neg_inf;
            }
          }
        }
      }
    }

    // Other masking as needed. Mask rows are indexed by query position without
    // qL_off. A bool mask hides entries; a non-bool mask is added after being
    // scaled by log2(e) to stay in base-2 units.
    if (has_mask) {
      using stile_t = decltype(Stile);
      using selem_t = typename stile_t::elem_type;
      constexpr auto neg_inf = Limits::finite_min;

      constexpr bool is_bool = is_same_v;
      using melem_t = typename metal::conditional_t;

      using MMAFrag_mask_t = BaseMMAFrag;
      using frag_t = typename MMAFrag_mask_t::frag_type;

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < stile_t::kTileRows; i++) {
        const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < stile_t::kTileCols; j++) {
          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);

          frag_t mfrag;

          // Out-of-range (qL, kL) mask entries are read as zero.
          MMAFrag_mask_t::load_safe(
              mfrag,
              mask,
              int64_t(mask_params->M_strides[2]),
              Int<1>{},
              params->qL,
              params->kL,
              row_pos,
              col_pos);

          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
            if constexpr (is_bool) {
              Stile.frag_at(i, j)[jj] =
                  mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
            } else {
              Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
            }
          }
        }
      }
    }

    // All K reads from KV_smem must finish before V overwrites it.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load V blocks
    if (!align_K && kb == (params->NK_aligned)) {
      loader_v.load_safe(short2(BD, params->kL_rem));
    } else {
      loader_v.load_unsafe();
    }

    // Do softmax

    // Temp variables
    AccumType new_max[kRowsPT];
    AccumType factor[kRowsPT];
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      new_max[i] = max_score[i];
    }

    // Row max
    Stile.template row_reduce(new_max);

    // exp(Si - rowmax(Si))
    Stile.template row_bin_op(new_max);

    // Factor exp(rowmax(Si) - rowmax(Si-1)) rescales earlier partial results
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      factor[i] = fast::exp2(max_score[i] - new_max[i]);
    }

    // Save max for next iteration
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      max_score[i] = new_max[i];
    }

    // Row Sum
    AccumType sum_score_tmp[kRowsPT] = {0};
    Stile.template row_reduce(sum_score_tmp);

    // Update norm
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
    }

    // Update O
    Otile.template row_bin_op(factor);

    // Load V into registers and accumulate O += P @ V
    threadgroup_barrier(mem_flags::mem_threadgroup);

    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short id = 0; id < TD; id++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          if constexpr (BD == 128) {
            simdgroup_barrier(mem_flags::mem_none);
          }

          const short kk = ik * kFragSize;
          const short dd = id * kFragSize;

          Vtile.template load(
              &Vs[Vs_offset + kk * LDV_tgp + dd]);

          if constexpr (BD == 128) {
            simdgroup_barrier(mem_flags::mem_none);
          }

          MMAFrag_acc_t::mma(
              Otile.frag_at(iq, id),
              Stile.frag_at(iq, ik),
              Vtile.frag_at(0, 0),
              Otile.frag_at(iq, id));
        }
      }
    }

    // Prepare for next iteration
    loader_k.next();
    loader_v.next();
  }

  // Normalize output
  Otile.template row_bin_op(sum_score);
  threadgroup_barrier(mem_flags::mem_none);

  // Store results
  O += (tm + sm) * params->O_strides[2] + sn;

  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
    // Lanes whose fragment starts past the valid rows/columns store nothing.
    auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    Otile.template store_safe(O, params->O_strides[2], dst_tile_dims);
  } else {
    Otile.template store(O, params->O_strides[2]);
  }
}

================================================
FILE: mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
================================================
// Copyright © 2024-25 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"

// Instantiates one attention kernel specialization. The host looks kernels up
// by the stringized name "steel_attention_{tname}_bq{bq}_bk{bk}_bd{bd}_wm{wm}_wn{wn}_mask{mname}".
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
  instantiate_kernel(                                                    \
      "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd            \
      "_wm" #wm "_wn" #wn "_mask" #mname,                                \
  attention, dtype, bq, bk, bd, wm, wn, mtype, float)

// Head dimensions 128, 80 and 64 with their block sizes.
#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
  instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
  instantiate_attn(iname, itype, 32, 32,  80, 4, 1, mname, mtype) \
  instantiate_attn(iname, itype, 32, 32,  64, 4, 1, mname, mtype)

// Each dtype gets a same-dtype additive mask variant and a bool mask variant.
#define instantiate_attn_mask_helper(iname, itype) \
  instantiate_attn_shapes_helper(iname, itype, iname, itype) \
  instantiate_attn_shapes_helper(iname, itype, bool_, bool)

instantiate_attn_mask_helper(float16, half);
instantiate_attn_mask_helper(bfloat16, bfloat16_t);

instantiate_attn_mask_helper(float32, float);
// clang-format on

================================================
FILE: mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h
================================================
// Copyright © 2024-25 Apple Inc.
#include "mlx/backend/metal/kernels/steel/attn/nax.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

// Specialization switches supplied at pipeline creation time.
// align_Q / align_K: when false, the last Q / KV block is partial and the
// bounds-checked load / masking paths are taken for it.
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];

// has_mask: explicit additive (float) or boolean mask bound at buffer 6.
// do_causal: causal mask computed from row/column positions.
// has_sinks: per-head sink logit bound at buffer 7.
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];

// Unary functor multiplying its input by a stored scale.
template
struct TransformScale {
  T scale;
  METAL_FUNC TransformScale(T scale_) : scale(scale_) {}

  METAL_FUNC T apply(T x) const {
    return scale * x;
  }
};

// Binary functors handed to the tile row_reduce / row_bin_op helpers.
struct MaxOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return metal::max(x, y);
  }
};

struct SumOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x + y;
  }
};

struct MulOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x * y;
  }
};

struct SubOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x - y;
  }
};

// 2^(x - y); scores are pre-scaled by log2(e), so this matches exp(x - y).
struct ExpSubOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return fast::exp2(x - y);
  }
};

struct DivOp {
  template
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x / y;
  }
};

// Fused attention forward kernel using NAX tiles (kU = 16):
// O = softmax(scale * Q K^T [+ masks]) V.
//
// Each threadgroup handles BQ query rows of one (batch = tid.z, head = tid.y)
// pair. Operands are read directly from device memory (no threadgroup
// staging), and the KV sequence is walked in blocks of BK with an online
// softmax (running row max / row sum, with earlier output rescaled by
// `factor`).
//
// Absolute query/key positions used for causal and explicit masking are held
// in `int`. Storing them in `short` would wrap for sequences of 32768 or more
// tokens and corrupt the masks.
// clang-format off
template <
    typename T,
    int BQ,
    int BK,
    int BD,
    int WM,
    int WN,
    typename MaskType = float,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax(
    const device T* Q [[buffer(0)]],
    const device T* K [[buffer(1)]],
    const device T* V [[buffer(2)]],
    device T* O [[buffer(3)]],
    const constant AttnParams* params [[buffer(4)]],
    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
    const device T* sinks [[buffer(7), function_constant(has_sinks)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on

  // Pacifying compiler
  (void)lid;
  (void)simd_lane_id;

  // Move to correct block
  ulong3 tidl{tid.x, tid.y, tid.z};

  Q += tidl.z * params->Q_strides[0] + // Batch
      tidl.y * params->Q_strides[1] + // Head
      tidl.x * BQ * params->Q_strides[2]; // Sequence

  // Grouped-query attention: gqa_factor query heads share one K/V head.
  ulong kv_head_idx = int(tid.y) / params->gqa_factor;
  K += tidl.z * params->K_strides[0] + // Batch
      kv_head_idx * params->K_strides[1]; // Head

  V += tidl.z * params->V_strides[0] + // Batch
      kv_head_idx * params->V_strides[1]; // Head

  O += tidl.z * params->O_strides[0] + // Batch
      tidl.y * params->O_strides[1] + // Head
      tidl.x * BQ * params->O_strides[2]; // Sequence

  if (has_mask) {
    mask += tidl.z * mask_params->M_strides[0] + // Batch
        tidl.y * mask_params->M_strides[1]; // Head
  }

  // Softmax scale with log2(e) folded in so exp2 can be used.
  const metal::uniform scale2 =
      make_uniform(params->scale) * make_uniform(1.44269504089f);

  // Prepare MMA tiles
  constexpr short kU = 16;

  constexpr int kNWarps = WM * WN;
  static_assert(
      BQ >= (kNWarps * kU) && BQ % (kNWarps * kU) == 0,
      "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");

  // Q seq frags per warp
  constexpr int TQ = BQ / (kNWarps * kU);
  // HeadDim frags (all warps load the same frags)
  constexpr int TD = BD / kU;
  // KV seq frags per warp
  constexpr short TK = BK / kU;

  static_assert(TQ == 1, "Check TQ");
  using otile_t = NAXTile;
  otile_t Otile;

  Otile.clear();

  // Prepare mma tile offsets: tm is the first query row of this simdgroup,
  // and (sn, sm) is this lane's (column, row) within a fragment.
  const short tm = kU * TQ * simd_group_id;
  Q += tm * int(params->Q_strides[2]);

  const short2 simd_coord = otile_t::NAXFrag_t::get_coord();
  const short sm = simd_coord.y;
  const short sn = simd_coord.x;

  // Init row reduction variables
  constexpr short kRowsPT = otile_t::kRowsPerThread;

  metal::vec max_score;
  metal::vec sum_score{0};

  // Init to -Inf (the most negative finite value stands in for -inf)
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    max_score[i] = Limits::finite_min;
  }

  // A sink adds one extra logit per head; seeding the running max with it
  // makes its running-sum contribution exp2(0) = 1.
  if (has_sinks) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]);
      sum_score[i] = 1;
    }
  }

  // kb_lim: number of KV blocks to visit. kb_min_causal: first KV block that
  // may need causal masking.
  int kb_lim = params->NK;
  int kb_min_causal = params->NK;

  if (do_causal) {
    int q_max = (tid.x + 1) * BQ + params->qL_off;
    kb_lim = (q_max + BK - 1) / BK;
    kb_lim = min(params->NK, kb_lim);

    int q_min = tid.x * BQ + params->qL_off;
    q_min = max(0, q_min);
    kb_min_causal = (q_min / BK);
  }

  const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
  // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
  const bool is_last_q = is_last_bq;

  // Valid rows remaining in the trailing partial Q / KV blocks (both < block size).
  const short lim_rows_q = params->qL_rem - tm;
  const short lim_rows_k = params->kL_rem;

  // Loop over KV seq length
  for (int kb = 0; kb < kb_lim; kb++) {
    const int is_last_k = (kb == (params->NK_aligned));

    // Do S = Q @ K.T
    using stile_t = NAXTile;
    stile_t Stile;

    Stile.clear();

    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short ik = 0; ik < TK; ik += 2) {
        STEEL_PRAGMA_UNROLL
        for (short id = 0; id < TD; id++) {
          NAXTile Qtile;
          NAXTile Ktile;

          const int Q_load_off = iq * kU * int(params->Q_strides[2]) + id * kU;
          const int K_load_off = ik * kU * int(params->K_strides[2]) + id * kU;

          if (!align_Q && is_last_q) {
            Qtile.load_rows(
                Q + Q_load_off,
                int(params->Q_strides[2]),
                lim_rows_q - iq * kU);
          } else {
            Qtile.load(Q + Q_load_off, int(params->Q_strides[2]));
          }

          if (!align_K && is_last_k) {
            Ktile.load_rows(
                K + K_load_off,
                int(params->K_strides[2]),
                lim_rows_k - ik * kU);
          } else {
            Ktile.load(K + K_load_off, int(params->K_strides[2]));
          }

          // K is consumed transposed (true_type) to form Q @ K.T.
          stile_t::NAXFrag_t::mma(
              Stile.frag_at(iq, ik),
              Stile.frag_at(iq, ik + 1),
              Qtile.frag_at(0, 0),
              metal::false_type{},
              Ktile.frag_at(0, 0),
              Ktile.frag_at(1, 0),
              metal::true_type{});
        }
      }
    }

    // Scale S
    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < stile_t::kElemsPerTile; ii++) {
      Stile.elems()[ii] *= float(scale2);
    }

    // Mask out length sequence: columns past kL_rem in the last partial block
    if (!align_K && is_last_k) {
      constexpr auto neg_inf = Limits::finite_min;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          const short col_pos = ik * kU + sn;

          thread auto& fg = Stile.frag_at(iq, ik);

          STEEL_PRAGMA_UNROLL
          for (short ii = 0; ii < stile_t::kFragThrRows; ii++) {
            STEEL_PRAGMA_UNROLL
            for (short jj = 0; jj < stile_t::kFragThrCols; jj++) {
              const auto loc = ii * stile_t::kFragThrCols + jj;
              fg[loc] = ((col_pos + jj) < params->kL_rem) ? fg[loc] : neg_inf;
            }
          }
        }
      }
    }

    // Mask out if causal
    if (do_causal && kb >= kb_min_causal) {
      constexpr auto neg_inf = Limits::finite_min;

      const int base_row = tid.x * BQ + params->qL_off + tm;
      const int base_col = kb * BK;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          // Absolute sequence positions: must stay int (short would wrap
          // past 32767 and flip the r < c comparison below).
          const int row_pos = base_row + iq * kU;
          const int col_pos = base_col + ik * kU;

          thread auto& fg = Stile.frag_at(iq, ik);

          STEEL_PRAGMA_UNROLL
          for (short ii = 0; ii < stile_t::kFragThrRows; ii++) {
            STEEL_PRAGMA_UNROLL
            for (short jj = 0; jj < stile_t::kFragThrCols; jj++) {
              const auto r = row_pos + ii * stile_t::kFragRowsJump + sm;
              const auto c = col_pos + jj + sn;
              const auto loc = ii * stile_t::kFragThrCols + jj;
              fg[loc] = (r < c) ? neg_inf : fg[loc];
            }
          }
        }
      }
    }

    // Other masking as needed. A bool mask hides entries; a non-bool mask is
    // added after scaling by log2(e). Mask rows are indexed without qL_off.
    if (has_mask) {
      constexpr auto neg_inf = Limits::finite_min;

      const int base_row = tid.x * BQ + tm;
      const int base_col = kb * BK;

      constexpr bool is_bool = is_same_v;
      using melem_t = typename metal::conditional_t;
      using mtile_t = NAXTile;
      using mfrag_t = typename mtile_t::frag_type;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          // Absolute mask offsets: must stay int so long sequences do not
          // wrap and read the wrong mask rows/columns.
          const int row_pos = base_row + iq * kU;
          const int col_pos = base_col + ik * kU;

          mfrag_t mfrag;
          mtile_t::NAXFrag_t::load_safe(
              mfrag,
              mask,
              int64_t(mask_params->M_strides[2]),
              Int<1>{},
              params->qL,
              params->kL,
              row_pos,
              col_pos);

          thread auto& fg = Stile.frag_at(iq, ik);

          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < mtile_t::kElemsPerFrag; jj++) {
            if constexpr (is_bool) {
              fg[jj] = mfrag[jj] ? fg[jj] : neg_inf;
            } else {
              fg[jj] += M_LOG2E_F * AccumType(mfrag[jj]);
            }
          }
        }
      }
    }

    // Do softmax

    // Temp variables
    metal::vec new_max;
    metal::vec factor;
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      new_max[i] = max_score[i];
    }

    // Row max
    Stile.template row_reduce(new_max);

    // exp(Si - rowmax(Si))
    Stile.template row_bin_op(new_max);

    // Factor exp(rowmax(Si) - rowmax(Si-1)); save max for next iteration
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      factor[i] = fast::exp2(max_score[i] - new_max[i]);
      max_score[i] = new_max[i];
    }

    // Row Sum: rescale the running sum, then accumulate this block into it
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      sum_score[i] = sum_score[i] * factor[i];
    }

    Stile.template row_reduce(sum_score);

    // Update O
    Otile.template row_bin_op(factor);

    simdgroup_barrier(mem_flags::mem_none);

    // Do O = P @ V
    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short id = 0; id < TD; id += 2) {
        if constexpr (BD == 128) {
          if (id == 4) {
            threadgroup_barrier(mem_flags::mem_none);
          }
        }

        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          NAXTile Vtile;

          const int V_load_off = ik * kU * int(params->V_strides[2]) + id * kU;

          if (!align_K && is_last_k) {
            Vtile.load_rows(
                V + V_load_off,
                int(params->V_strides[2]),
                lim_rows_k - ik * kU);
          } else {
            Vtile.load(V + V_load_off, int(params->V_strides[2]));
          }

          otile_t::NAXFrag_t::mma(
              Otile.frag_at(iq, id),
              Otile.frag_at(iq, id + 1),
              Stile.frag_at(iq, ik),
              metal::false_type{},
              Vtile.frag_at(0, 0),
              Vtile.frag_at(0, 1),
              metal::false_type{});
        }
      }
    }

    // Prepare for next iteration
    K += BK * int(params->K_strides[2]);
    V += BK * int(params->V_strides[2]);
  }

  // Normalize output

  threadgroup_barrier(mem_flags::mem_none);

  metal::vec rcp;
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    rcp[i] = 1.f / sum_score[i];
  }

  Otile.template row_bin_op(rcp);

  // Store results (simdgroups with no valid rows in a partial block return)
  O += tm * int(params->O_strides[2]);

  if (!align_Q && is_last_q) {
    if (lim_rows_q <= 0)
      return;

    Otile.store_rows(O, int(params->O_strides[2]), lim_rows_q);
  } else {
    Otile.store(O, int(params->O_strides[2]));
  }
}

================================================
FILE: mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal
================================================
// Copyright © 2024-25 Apple Inc.
// clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h" #define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ instantiate_kernel( \ "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ "_wm" #wm "_wn" #wn "_mask" #mname, \ attention_nax, dtype, bq, bk, bd, wm, wn, mtype, float) #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ instantiate_attn(iname, itype, 64, 32, 128, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 64, 32, 64, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 64, 64, 128, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 64, 64, 64, 4, 1, mname, mtype) #define instantiate_attn_mask_helper(iname, itype) \ instantiate_attn_shapes_helper(iname, itype, iname, itype) \ instantiate_attn_shapes_helper(iname, itype, bool_, bool) instantiate_attn_mask_helper(float16, half); instantiate_attn_mask_helper(bfloat16, bfloat); instantiate_attn_mask_helper(float32, float); // clang-format on ================================================ FILE: mlx/backend/metal/kernels/steel/attn/loader.h ================================================ // Copyright © 2024 Apple Inc. 
#pragma once

#include "mlx/backend/metal/kernels/steel/defines.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

// Cooperative tile loader: copies a BROWS x BCOLS tile from device memory into
// threadgroup memory (destination leading dimension dst_ld).
// Each of the tgp_size threads owns vec_size (= n_reads) contiguous elements:
// - its row is bi = thread_idx / TCOLS;
// - its starting column is bj = vec_size * (thread_idx % TCOLS);
// - it then steps down the tile TROWS rows at a time.
// next() advances src by one tile: BCOLS columns when reduction_dim is
// nonzero, otherwise BROWS rows.
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  // Raw byte block of vec_size elements, used to copy a thread's row slice in
  // a single assignment in load_unsafe().
  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

  /* Constructor */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Apply operation to threadgroup without bound checking */
  template
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
      }
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
          *((const device ReadVector*)(&src[i * src_ld]));
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking.
   * src_tile_dim is (valid columns, valid rows) of the source tile; elements
   * outside it are written as zero. */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    // Convert to extents relative to this thread's starting (bj, bi).
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val. Invalid slots read src[0], which is
      // in bounds because of the early return above; they are zeroed next.
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};

// Compile-time (rows, cols) pair.
template
struct CShape {
  STEEL_CONST int kRows = R;
  STEEL_CONST int kCols = C;
};

// Variant of BlockLoader whose threadgroup destination has independent row
// and column strides (kDstStrRow, kDstStrCol). Setting kDstStrRow = 1 and
// kDstStrCol to a leading dimension stores the tile transposed. Elements are
// copied one at a time rather than as a ReadVector.
template <
    typename T,
    short BROWS,
    short BCOLS,
    short kDstStrRow,
    short kDstStrCol,
    short reduction_dim,
    short tgp_size,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  /* Constructor */
  METAL_FUNC BlockLoaderT(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
        src(src_ + bi * src_ld + bj) {}

  /* Apply operation to threadgroup without bound checking */
  template
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * kDstStrRow + j * kDstStrCol] =
            op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
      }
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking.
   * src_tile_dim is (valid columns, valid rows); out-of-range elements are
   * written as zero. */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * kDstStrRow + j * kDstStrCol] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val (invalid slots read in-bounds src[0])
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};

} // namespace steel
} // namespace mlx

================================================
FILE: mlx/backend/metal/kernels/steel/attn/mma.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

#include
#include
#include

#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template
struct Shape2D {
  RInt r;
  CInt c;

  Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
};

template
struct Layout2D {
  Shape shape;
  Layout layout;
};

// Primary template: only 8 x 8 fragments are supported.
template
struct BaseMMAFrag {
  static_assert(
      kFragRows_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
  static_assert(
      kFragCols_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
};

// 8 x 8 simdgroup-matrix fragment. Each of the 32 lanes holds
// kElemRows x kElemCols = 1 x 2 elements of the fragment in registers.
template
struct BaseMMAFrag {
  STEEL_CONST int kFragRows = 8;
  STEEL_CONST int kFragCols = 8;

  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;

  STEEL_CONST int kElemRows = 1;
  STEEL_CONST int kElemCols = 2;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "MMAFrag shape is not consistent with MMAFrag size");

  typedef metal::simdgroup_matrix mat_type;
  typedef metal::vec frag_type;
  typedef metal::vec row_frag_type;
  typedef metal::vec col_frag_type;

  template
  using dtype_mat_t = typename metal::simdgroup_matrix;

  template
  using dtype_frag_t = typename metal::vec;

  // Returns (column, row) = (fn, fm) of this lane's first element within the
  // 8 x 8 fragment.
  METAL_FUNC static constexpr short2 get_coord(
      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
    const short qid = simd_lane_id / 4;
    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
    return short2{fn, fm};
  }

  // Loads this lane's elements from src using strides (str_x rows, str_y cols).
  template
  METAL_FUNC static constexpr void
  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]);
      }
    }
  }

  // Bounds-checked load starting at (off_x, off_y): elements whose row/col is
  // not below (lim_x, lim_y) are set to zero instead of being read.
  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void load_safe(
      thread frag_type& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    src += off_x * str_x + off_y * str_y;
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[i * kElemCols + j] = static_cast(src[0]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
        src += str_y;
      }
      src -= kElemCols * str_y;
      src += str_x;
    }
  }

  // Stores this lane's elements to dst, converting to dst's element type.
  template
  METAL_FUNC static constexpr void
  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
    using U = pointer_element_t;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]);
      }
    }
  }

  // Bounds-checked store: only elements inside (lim_x, lim_y) are written.
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void store_safe(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast(src[i * kElemCols + j]);
        }
      }
    }
  }

  // D = A * B + C on register fragments: the per-lane vectors are
  // reinterpreted as simdgroup_matrix thread elements, multiplied, and the
  // result is copied back out.
  template
  METAL_FUNC static constexpr void mma(
      thread frag_type& D,
      thread dtype_frag_t& A,
      thread dtype_frag_t& B,
      thread dtype_frag_t& C) {
    mat_type D_mat;
    dtype_mat_t A_mat;
    dtype_mat_t B_mat;
    dtype_mat_t C_mat;

    reinterpret_cast&>(A_mat.thread_elements()) = A;
    reinterpret_cast&>(B_mat.thread_elements()) = B;
    reinterpret_cast&>(C_mat.thread_elements()) = C;

    mma(D_mat, A_mat, B_mat, C_mat);

    D = reinterpret_cast(D_mat.thread_elements());
  }

  template
  METAL_FUNC static constexpr void mma(
      thread mat_type& D,
      thread dtype_mat_t& A,
      thread dtype_mat_t& B,
      thread dtype_mat_t& C) {
    simdgroup_multiply_accumulate(D, A, B, C);
  }

  // Reduces this lane's row with Op across all lanes sharing that row: first
  // the lane's own 2 elements, then xor-1 and xor-8 lane shuffles. The result
  // is folded into reduced_vals[0].
  template
  METAL_FUNC static constexpr void row_reduce(
      thread const frag_type& inp_vals,
      thread T* reduced_vals) {
    T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);

    T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
    qgr_reduce = Op::apply(thr_reduce, qgr_reduce);

    T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
    sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);

    reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
  }

  // Applies Op(element, row_vals[row]) to every element this lane holds.
  template
  METAL_FUNC static constexpr void row_bin_op(
      thread frag_type& inp_vals,
      thread T* row_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        inp_vals[i * kElemCols + j] =
            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
      }
    }
  }
};

template <
    typename T,
    int kTileRows_,
    int kTileCols_,
    class MMAFrag_ = BaseMMAFrag>
struct MMATile {
  using MMAFrag_t = MMAFrag_;
  using elem_type = T;
  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;

  STEEL_CONST int kTileRows = kTileRows_;
  STEEL_CONST int kTileCols = kTileCols_;

  STEEL_CONST int kRows = kTileRows * kFragRows;
  STEEL_CONST int kCols = kTileCols * kFragCols;

  STEEL_CONST int kNumFrags = kTileRows * kTileCols;
  STEEL_CONST int kElemsPerTile = kNumFrags *
kElemsPerFrag; STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows; STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols; typedef typename MMAFrag_t::mat_type mat_type; typedef typename MMAFrag_t::frag_type frag_type; frag_type val_frags[kNumFrags]; // = {frag_type(0)}; METAL_FUNC MMATile() thread {} METAL_FUNC constexpr void clear() { STEEL_PRAGMA_UNROLL for (short i = 0; i < kNumFrags; ++i) { val_frags[i] = frag_type(0); } } METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { return val_frags[i * kTileCols + j]; } METAL_FUNC constexpr const thread frag_type& frag_at( const short i, const short j) const { return val_frags[i * kTileCols + j]; } METAL_FUNC mat_type mat_at(const short i, const short j) { mat_type val_mat; STEEL_PRAGMA_UNROLL for (short ii = 0; ii < kElemsPerFrag; ++ii) { val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; } return val_mat; } METAL_FUNC thread elem_type* elems() { return reinterpret_cast(val_frags); } METAL_FUNC const thread elem_type* elems() const { return reinterpret_cast(val_frags); } template METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const { STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { MMAFrag_t::template row_reduce( frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); } } } template METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) { STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { MMAFrag_t::template row_bin_op( frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); } } } template METAL_FUNC void load(const threadgroup U* src) { STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { MMAFrag_t::load( frag_at(i, j), &( src[(i * kFragRows) * w_x * str_x + (j * kFragCols) * w_y * str_y]), Int{}, Int{}); } } } template METAL_FUNC void store(threadgroup U* dst) 
const { STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { MMAFrag_t::store( frag_at(i, j), &( dst[(i * kFragRows) * w_x * str_x + (j * kFragCols) * w_y * str_y]), Int{}, Int{}); } } } template METAL_FUNC void load(const device U* src, const int ld) { STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { MMAFrag_t::load( frag_at(i, j), &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), ld, Int<1>{}); } } } template METAL_FUNC void store(device U* dst, const int ld) const { STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { MMAFrag_t::store( frag_at(i, j), &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), ld, Int<1>{}); } } } template METAL_FUNC void load_safe(const device U* src, const int ld, const short2 src_tile_dims) { STEEL_PRAGMA_UNROLL for (int i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (int j = 0; j < kTileCols; ++j) { MMAFrag_t::load_safe( frag_at(i, j), src, ld, Int<1>{}, src_tile_dims.y, src_tile_dims.x, (i * kFragRows) * w_x, (j * kFragCols) * w_y); } } } template METAL_FUNC void store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { STEEL_PRAGMA_UNROLL for (int i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (int j = 0; j < kTileCols; ++j) { MMAFrag_t::store_safe( frag_at(i, j), dst, ld, Int<1>{}, dst_tile_dims.y, dst_tile_dims.x, (i * kFragRows) * w_x, (j * kFragCols) * w_y); } } } }; template < typename Dtype, typename Atype, typename Btype, typename Ctype, int M, int N, int K, class MMAFragD, class MMAFragA, class MMAFragB, class MMAFragC> METAL_FUNC void tile_matmad( thread MMATile& D, thread MMATile& A, thread MMATile& B, thread MMATile& C) { STEEL_PRAGMA_UNROLL for (short m = 0; m < M; ++m) { STEEL_PRAGMA_UNROLL for (short n = 0; n < N; ++n) { short m_serp = m; //(n % 2) ? 
(M - 1 - m) : m; short n_serp = (m % 2) ? (N - 1 - n) : n; STEEL_PRAGMA_UNROLL for (short k = 0; k < K; ++k) { MMAFragD::mma( D.frag_at(m_serp, n_serp), A.frag_at(m_serp, k), B.frag_at(k, n_serp), C.frag_at(m_serp, n_serp)); } } } } template < typename T, typename U, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, short lda_tgp, short ldb_tgp, typename AccumType = float, typename Epilogue = TransformNone> struct BlockMMA { // MMAFrag size STEEL_CONST short kFragSize = 8; using MMAFrag_acc_t = BaseMMAFrag; // Warp tile simdgroup matrix strides along M STEEL_CONST short TM_stride = kFragSize * WM; // Warp tile simdgroup matrix strides along M STEEL_CONST short TN_stride = kFragSize * WN; // Warp tile size along M STEEL_CONST short TM = BM / TM_stride; // Warp tile size along N STEEL_CONST short TN = BN / TN_stride; // Threadgroup A strides STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K // Threadgroup B strides STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K STEEL_CONST short B_str_n = transpose_b ? 
ldb_tgp : 1; // N // Threadgroup strides along K STEEL_CONST short tile_stride_a = kFragSize * A_str_k; STEEL_CONST short tile_stride_b = kFragSize * B_str_k; // Simdgroup matrices MMATile Atile; MMATile Btile; MMATile Ctile; // Offsets within threadgroup short sm; short sn; short As_offset; short Bs_offset; /* Constructor */ METAL_FUNC BlockMMA( ushort simd_group_id [[simdgroup_index_in_threadgroup]], ushort simd_lane_id [[thread_index_in_simdgroup]]) { // Determine thread position in simdgroup matrix short tm = kFragSize * (simd_group_id / WN); short tn = kFragSize * (simd_group_id % WN); short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); sm = simd_coord.y; sn = simd_coord.x; // Determine thread and simdgroup offset As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N sm += tm; sn += tn; } /* (BM, BK) X (BK, BN) multiply accumulate function */ METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { // Adjust for simdgroup and thread location As += As_offset; Bs += Bs_offset; // Iterate over BK in blocks of kFragSize STEEL_PRAGMA_UNROLL for (short kk = 0; kk < BK; kk += kFragSize) { simdgroup_barrier(mem_flags::mem_none); Atile.template load(As); simdgroup_barrier(mem_flags::mem_none); Btile.template load(Bs); simdgroup_barrier(mem_flags::mem_none); tile_matmad(Ctile, Atile, Btile, Ctile); // Progress to next simdgroup tile As += tile_stride_a; Bs += tile_stride_b; } } /* Store results from simdgroup_matrix results into device memory */ METAL_FUNC void store_result(device U* D, const int ldd) { // Apply epilogue STEEL_PRAGMA_UNROLL for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); } // Adjust for simdgroup and thread location D += sm * ldd + sn; Ctile.template store(D, ldd); } METAL_FUNC void store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { // Apply epilogue STEEL_PRAGMA_UNROLL for (short i = 0; i 
< decltype(Ctile)::kElemsPerTile; i++) { Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); } // Adjust for simdgroup and thread location D += sm * ldd + sn; dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; Ctile.template store_safe(D, ldd, dst_tile_dims); } /* Apply epilogue */ template METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); } } /* Apply epilogue */ template METAL_FUNC void apply_epilogue( const device U* C, const int ldc, const int fdc, thread const BinaryEpilogue& epilogue_op) { // Adjust for simdgroup and thread location C += (sm)*ldc + (sn)*fdc; // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL for (short i = 0; i < TM; i++) { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C thread auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; // Apply epilogue STEEL_PRAGMA_UNROLL for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); } } } } /* Apply epilogue */ template METAL_FUNC void apply_epilogue_safe( const device U* C, const int ldc, const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue& epilogue_op) { // Adjust for simdgroup and thread location C += (sm)*ldc + (sn)*fdc; dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL for (short i = 0; i < TM; i++) { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C thread auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; constexpr short kelems = decltype(Ctile)::kElemsPerFrag; // Read C U c_elems[kelems] = 
{0}; STEEL_PRAGMA_UNROLL for (short k = 0; k < kelems; k++) { if ((j * TN_stride + k) < dst_tile_dims.x) { c_elems[k] = C[offset_c + k * fdc]; } } // Apply epilogue STEEL_PRAGMA_UNROLL for (short k = 0; k < kelems; k++) { accum[k] = epilogue_op.apply(accum[k], c_elems[k]); } } } } /* Store results from simdgroup_matrix results into device memory */ METAL_FUNC void store_result( device U* D, const int ldd, const device U* C, const int ldc, const int fdc, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location C += (sm)*ldc + (sn)*fdc; D += (sm)*ldd + sn; constexpr short kelems = decltype(Ctile)::kElemsPerFrag; // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL for (short i = 0; i < TM; i++) { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); // Apply epilogue STEEL_PRAGMA_UNROLL for (short k = 0; k < kelems; k++) { D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); } } } } METAL_FUNC void store_result_safe( device U* D, const int ldd, const device U* C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location C += (sm)*ldc + (sn)*fdc; D += (sm)*ldd + sn; dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; constexpr short kelems = decltype(Ctile)::kElemsPerFrag; STEEL_PRAGMA_UNROLL for (int i = 0; i < TM; i++) { if (i * TM_stride < dst_tile_dims.y) { STEEL_PRAGMA_UNROLL for (int j = 0; j < TN; j++) { // Get accumulated result and associated offset in C thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); // Apply epilogue STEEL_PRAGMA_UNROLL for (short k = 0; k < 
kelems; k++) { if ((j * TN_stride + k) < dst_tile_dims.x) { D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); } } } } } } }; } // namespace steel } // namespace mlx ================================================ FILE: mlx/backend/metal/kernels/steel/attn/nax.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include #include #include #include "mlx/backend/metal/kernels/steel/defines.h" #include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" #include using namespace metal; /////////////////////////////////////////////////////////////////////////////// // MMA helper /////////////////////////////////////////////////////////////////////////////// namespace mlx { namespace steel { /////////////////////////////////////////////////////////////////////////////// // NAX Steel with new tiles /////////////////////////////////////////////////////////////////////////////// struct BaseNAXFrag { STEEL_CONST short kFragRows = 16; STEEL_CONST short kFragCols = 16; STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; STEEL_CONST short kElemRows = 2; STEEL_CONST short kElemCols = 4; STEEL_CONST short kElemRowsJump = 8; static_assert( kElemRows * kElemCols == kElemsPerFrag, "MMAFrag shape is not consistent with MMAFrag size"); template using dtype_frag_t = typename metal::vec; METAL_FUNC static short2 get_coord() { const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); const short qid = simd_lane_id >> 2; const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; return short2{fn, fm}; } METAL_FUNC static short2 get_coord(short idx) { const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); const short qid = simd_lane_id >> 2; const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; return short2{fn, fm}; } template < 
typename T, typename SrcPtrType, typename StrX, typename StrY, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void load( thread dtype_frag_t& dst, SrcPtrType src, StrX str_x, StrY str_y, OffX off_x = {}, OffY off_y = {}) { const short2 sc = get_coord(); src += sc.y * str_x + sc.x * str_y; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j) * str_y]); } } } } template < typename T, typename SrcPtrType, typename StrX, typename StrY, typename LimX, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void load_rows( thread dtype_frag_t& dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, OffX off_x = {}, OffY off_y = {}) { const short2 sc = get_coord(); src += sc.y * str_x + sc.x * str_y; auto lx = lim_x - sc.y; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; if (r < lx) { if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j) * str_y]); } } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[i * kElemCols + j] = T(0); } } } } template < typename T, typename SrcPtrType, typename StrX, typename StrY, typename LimX, typename LimY, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void load_safe( thread dtype_frag_t& dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x = {}, OffY 
off_y = {}) { const short2 sc = get_coord(); src += sc.y * str_x + sc.x * str_y; auto lx = lim_x - sc.y; auto ly = lim_y - sc.x; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { if ((r < lx) && ((c + j) < ly)) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j) * str_y]); } else { dst[i * kElemCols + j] = T(0); } } } } template < typename T, typename DstPtrType, typename StrX, typename StrY, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void store( const thread dtype_frag_t& src, DstPtrType dst, StrX str_x, StrY str_y, OffX off_x = {}, OffY off_y = {}) { using U = pointer_element_t; const short2 sc = get_coord(); dst += sc.y * str_x + sc.x * str_y; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[r * str_x + (c + j) * str_y] = static_cast(src[i * kElemCols + j]); } } } } template < typename T, typename DstPtrType, typename StrX, typename StrY, typename LimX, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void store_rows( const thread dtype_frag_t& src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, OffX off_x = {}, OffY off_y = {}) { using U = pointer_element_t; const short2 sc = get_coord(); dst += sc.y * str_x + sc.x * str_y; auto lx = lim_x - sc.y; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; if (r < lx) { if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); } } else { 
STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[r * str_x + (c + j) * str_y] = static_cast(src[i * kElemCols + j]); } } } } } template < typename T, typename DstPtrType, typename StrX, typename StrY, typename LimX, typename LimY, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void store_safe( const thread dtype_frag_t& src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x = {}, OffY off_y = {}) { using U = pointer_element_t; const short2 sc = get_coord(); dst += sc.y * str_x + sc.x * str_y; auto lx = lim_x - sc.y; auto ly = lim_y - sc.x; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { if (r < lx && (c + j) < ly) { dst[r * str_x + (c + j) * str_y] = static_cast(src[i * kElemCols + j]); } } } } template < typename T, typename DstPtrType, typename StrX, typename StrY, typename StartX, typename StopX, typename StartY, typename StopY, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void store_slice( const thread dtype_frag_t& src, DstPtrType dst, StrX str_x, StrY str_y, StartX start_x, StopX stop_x, StartY start_y, StopY stop_y, OffX off_x = Int<0>{}, OffY off_y = Int<0>{}) { using U = pointer_element_t; const short2 sc = get_coord(); const_for_loop<0, kElemRows, 1>([&](auto idx_row) { const auto r = off_x + idx_row * Int{}; if (r >= stop_x - sc.y || r < start_x - sc.y) { return; } const_for_loop<0, kElemCols, 1>([&](auto idx_col) { const auto c = off_y + idx_col; if (c >= stop_y - sc.x || c < start_y - sc.x) { return; } const auto src_idx = idx_row * Int{} + idx_col; dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = static_cast(src[src_idx]); }); }); } template METAL_FUNC static constexpr void row_reduce( thread const dtype_frag_t& inp_vals, thread T* reduced_vals) { STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { T thr_reduce = 
Op::apply( Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); qgr_reduce = Op::apply(thr_reduce, qgr_reduce); T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); } } template METAL_FUNC static constexpr void row_bin_op( thread dtype_frag_t& inp_vals, thread T* row_vals) { STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { inp_vals[i * kElemCols + j] = Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); } } } template < typename CType, typename AType, typename BType, bool transpose_a = false, bool transpose_b = false> METAL_FUNC static constexpr void mma( thread dtype_frag_t& Cn0, thread dtype_frag_t& Cn1, const thread dtype_frag_t& A, metal::bool_constant, const thread dtype_frag_t& Bn0, const thread dtype_frag_t& Bn1, metal::bool_constant) { constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( 16, 32, 16, transpose_a, transpose_b, true, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); // Create matmul op mpp::tensor_ops::matmul2d gemm_op; // Create matmul operands in registers auto ct_a = gemm_op .template get_left_input_cooperative_tensor(); auto ct_b = gemm_op .template get_right_input_cooperative_tensor(); // Create matmul output in register auto ct_c = gemm_op.template get_destination_cooperative_tensor< decltype(ct_a), decltype(ct_b), CType>(); // Load A in to left operand registers STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_a[i] = A[i]; } // Load B into right operand registers STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_b[i] = Bn0[i]; ct_b[kElemsPerFrag + i] = Bn1[i]; } // Load C into output registers (op handles accumulation) STEEL_PRAGMA_UNROLL for (short i = 0; i < 
kElemsPerFrag; i++) { ct_c[i] = Cn0[i]; ct_c[kElemsPerFrag + i] = Cn1[i]; } // Do matmul gemm_op.run(ct_a, ct_b, ct_c); // Copy out results STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { Cn0[i] = ct_c[i]; Cn1[i] = ct_c[kElemsPerFrag + i]; } } template < typename CType, typename AType, typename BType, bool transpose_a = false, bool transpose_b = false> METAL_FUNC static constexpr void mma( thread dtype_frag_t& Cm0, thread dtype_frag_t& Cm1, const thread dtype_frag_t& Am0, const thread dtype_frag_t& Am1, metal::bool_constant, const thread dtype_frag_t& B, metal::bool_constant) { // Create Matmul descriptor constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( 16, 32, 16, transpose_a, transpose_b, true, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); // Create matmul op mpp::tensor_ops::matmul2d gemm_op; // Create matmul operands in registers auto ct_a = gemm_op .template get_left_input_cooperative_tensor(); auto ct_b = gemm_op .template get_right_input_cooperative_tensor(); // Create matmul output in register auto ct_c = gemm_op.template get_destination_cooperative_tensor< decltype(ct_a), decltype(ct_b), CType>(); // Load A in to left operand registers STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_a[i] = Am0[i]; ct_a[kElemsPerFrag + i] = Am1[i]; } // Load B into right operand registers STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_b[i] = B[i]; } // Load C into output registers (op handles accumulation) STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_c[i] = Cm0[i]; ct_c[kElemsPerFrag + i] = Cm1[i]; } // Do matmul gemm_op.run(ct_a, ct_b, ct_c); // Copy out results STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { Cm0[i] = ct_c[i]; Cm1[i] = ct_c[kElemsPerFrag + i]; } } }; template < typename T, short kTileRows_, short kTileCols_, class NAXFrag_ = BaseNAXFrag> struct NAXTile { using NAXFrag_t = NAXFrag_; using elem_type = T; STEEL_CONST short kFragRows = 
NAXFrag_t::kFragRows; STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; STEEL_CONST short kTileRows = kTileRows_; STEEL_CONST short kTileCols = kTileCols_; STEEL_CONST short kRows = kTileRows * kFragRows; STEEL_CONST short kCols = kTileCols * kFragCols; STEEL_CONST short kNumFrags = kTileRows * kTileCols; STEEL_CONST short kElemsPerTile = kNumFrags * kElemsPerFrag; STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; STEEL_CONST short kRowsPerThread = kTileRows * NAXFrag_t::kElemRows; STEEL_CONST short kColsPerThread = kTileCols * NAXFrag_t::kElemCols; typedef typename NAXFrag_t::template dtype_frag_t frag_type; frag_type val_frags[kNumFrags]; // = {frag_type(0)}; METAL_FUNC NAXTile() thread {} METAL_FUNC constexpr void clear() { STEEL_PRAGMA_UNROLL for (short i = 0; i < kNumFrags; ++i) { val_frags[i] = frag_type(0); } } METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { return val_frags[i * kTileCols + j]; } METAL_FUNC constexpr const thread frag_type& frag_at( const short i, const short j) const { return val_frags[i * kTileCols + j]; } template METAL_FUNC constexpr thread frag_type& frag_at() { return val_frags[i * kTileCols + j]; } template METAL_FUNC constexpr const thread frag_type& frag_at() const { return val_frags[i * kTileCols + j]; } template METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j, metal::bool_constant) { if constexpr (transpose) { return frag_at(j, i); } else { return frag_at(i, j); } } template METAL_FUNC constexpr const thread frag_type& frag_at(const short i, const short j, metal::bool_constant) const { if constexpr (transpose) { return frag_at(j, i); } else { return frag_at(i, j); } } template METAL_FUNC constexpr thread frag_type& frag_at() { if constexpr (transpose) { return frag_at(); } else { return 
frag_at(); } } template METAL_FUNC constexpr const thread frag_type& frag_at() const { if constexpr (transpose) { return frag_at(); } else { return frag_at(); } } METAL_FUNC thread elem_type* elems() { return reinterpret_cast(val_frags); } METAL_FUNC const thread elem_type* elems() const { return reinterpret_cast(val_frags); } template METAL_FUNC void row_reduce(thread metal::vec& vals) const { auto vptr = (thread T*)(&vals); STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { NAXFrag_t::template row_reduce( frag_at(i, j), &vptr[i * kFragThrRows]); } } } template METAL_FUNC void row_bin_op(thread metal::vec& vals) { auto vptr = (thread T*)(&vals); STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { NAXFrag_t::template row_bin_op( frag_at(i, j), &vptr[i * kFragThrRows]); } } } template METAL_FUNC void load(const threadgroup U* src) { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::load( frag_at(), src, Int{}, Int{}, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void store(threadgroup U* dst) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::store( frag_at(), dst, Int{}, Int{}, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void load(const device U* src, const int ld) { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::load( frag_at(), src, ld, Int<1>{}, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void store(device U* dst, const int ld) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::store( frag_at(), dst, ld, Int<1>{}, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void 
load_rows(const device U* src, const int ld, const short n_rows) { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::load_rows( frag_at(), src, ld, Int<1>{}, n_rows, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void load_safe(const device U* src, const int ld, const short2 src_tile_dims) { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::load_safe( frag_at(), src, ld, Int<1>{}, src_tile_dims.y, src_tile_dims.x, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::store_rows( frag_at(), dst, ld, Int<1>{}, n_rows, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::store_safe( frag_at(), dst, ld, Int<1>{}, dst_tile_dims.y, dst_tile_dims.x, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void store_slice( device U* dst, const int ld, const short2 start, const short2 stop) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::store_slice( frag_at(), dst, ld, Int<1>{}, start.y, stop.y, start.x, stop.x, idx_row * Int{}, idx_col * Int{}); }); }); } }; template < class CTile, class ATile, class BTile, bool transpose_a, bool transpose_b> METAL_FUNC void tile_matmad_nax( thread CTile& C, thread ATile& A, metal::bool_constant, thread BTile& B, metal::bool_constant) { // Static checks constexpr short TMa = transpose_a ? 
ATile::kTileCols : ATile::kTileRows; constexpr short TM = CTile::kTileRows; static_assert(TMa == TM, "MXU tile matmul: M dimensions do not match"); constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; constexpr short TN = CTile::kTileCols; static_assert(TNb == TN, "MXU tile matmul: N dimensions do not match"); constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; constexpr short TK = transpose_b ? BTile::kTileCols : BTile::kTileRows; static_assert(TKa == TK, "MXU tile matmul: K dimensions do not match"); constexpr auto ta = metal::bool_constant{}; constexpr auto tb = metal::bool_constant{}; if constexpr (TN == 1 && TM % 2 == 0) { STEEL_PRAGMA_UNROLL for (short mm = 0; mm < TM; mm += 2) { STEEL_PRAGMA_UNROLL for (short nn = 0; nn < TN; ++nn) { STEEL_PRAGMA_UNROLL for (short kk = 0; kk < TK; ++kk) { CTile::NAXFrag_t::mma( C.frag_at(mm, nn), C.frag_at(mm + 1, nn), A.frag_at(mm, kk, ta), A.frag_at(mm + 1, kk, ta), metal::bool_constant{}, B.frag_at(kk, nn, tb), metal::bool_constant{}); } } } } else if constexpr (TN % 2 == 0) { STEEL_PRAGMA_UNROLL for (short mm = 0; mm < TM; ++mm) { STEEL_PRAGMA_UNROLL for (short nn = 0; nn < TN; nn += 2) { STEEL_PRAGMA_UNROLL for (short kk = 0; kk < TK; ++kk) { CTile::NAXFrag_t::mma( C.frag_at(mm, nn), C.frag_at(mm, nn + 1), A.frag_at(mm, kk, ta), metal::bool_constant{}, B.frag_at(kk, nn, tb), B.frag_at(kk, nn + 1, tb), metal::bool_constant{}); } } } } } } // namespace steel } // namespace mlx ================================================ FILE: mlx/backend/metal/kernels/steel/attn/params.h ================================================ // Copyright © 2024 Apple Inc. 
#pragma once /////////////////////////////////////////////////////////////////////////////// // Attn param classes /////////////////////////////////////////////////////////////////////////////// namespace mlx { namespace steel { struct AttnParams { int B; ///< Batch Size int H; ///< Heads int D; ///< Head Dim int qL; ///< Query Sequence Length int kL; ///< Key Sequence Length int gqa_factor; ///< Group Query factor float scale; ///< Attention scale int NQ; ///< Number of query blocks int NK; ///< Number of key/value blocks int NQ_aligned; ///< Number of full query blocks int NK_aligned; ///< Number of full key/value blocks int qL_rem; ///< Remainder in last query block int kL_rem; ///< Remainder in last key/value block int qL_off; ///< Offset in query sequence start int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) }; struct AttnMaskParams { int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) }; } // namespace steel } // namespace mlx ================================================ FILE: mlx/backend/metal/kernels/steel/attn/transforms.h ================================================ // Copyright © 2024 Apple Inc. 
#pragma once #include "mlx/backend/metal/kernels/steel/utils.h" /////////////////////////////////////////////////////////////////////////////// // Transforms and Epilogues /////////////////////////////////////////////////////////////////////////////// namespace mlx { namespace steel { template struct TransformNone { static METAL_FUNC OutT apply(InT x) { return static_cast(x); } static METAL_FUNC OutT apply(InT x, OutT) { return static_cast(x); } }; template struct TransformAdd { TransformAdd(const float, const float) {} static METAL_FUNC OutT apply(InT x) { return static_cast(x); } static METAL_FUNC OutT apply(InT x, OutT c) { return static_cast(x) + c; } }; template struct TransformAxpby { const float alpha; const float beta; TransformAxpby(const float alpha_, const float beta_) : alpha(alpha_), beta(beta_) {} static METAL_FUNC OutT apply(InT x) { return static_cast(x); } METAL_FUNC OutT apply(InT x, OutT c) const { return static_cast(x * alpha + (beta * c)); } }; template struct AccumHelper { typedef float accum_type; }; struct BlockSwizzle { static METAL_FUNC int2 swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { const int tid_x = (tid.x) >> swizzle_log; const int tid_y = ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); return int2(tid_x, tid_y); } }; } // namespace steel } // namespace mlx ================================================ FILE: mlx/backend/metal/kernels/steel/conv/conv.h ================================================ // Copyright © 2024 Apple Inc. 
#pragma once #include "mlx/backend/metal/kernels/steel/defines.h" #include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/steel/conv/loader.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h" using namespace metal; using namespace mlx::steel; ================================================ FILE: mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h ================================================ // Copyright © 2024 Apple Inc. #include using namespace metal; template < typename T, int BM, int BN, int BK, int WM, int WN, int N_CHANNELS = 0, bool SMALL_FILTER = false> [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], device T* C [[buffer(2)]], const constant MLXConvParams<2>* params [[buffer(3)]], const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { using namespace mlx::steel; (void)lid; constexpr bool transpose_a = false; constexpr bool transpose_b = true; constexpr short tgp_padding_a = 16 / sizeof(T); constexpr short tgp_padding_b = 16 / sizeof(T); constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; constexpr short shape_a_rows = (transpose_a ? BK : BM); constexpr short shape_b_rows = (transpose_b ? 
BN : BK); constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; constexpr short tgp_size = WM * WN * 32; // Input loader using loader_a_t = typename metal::conditional_t< // Check for small channel specialization N_CHANNELS != 0 && N_CHANNELS <= 4, // Go to small channel specialization Conv2DInputBlockLoaderSmallChannels< T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>, // Else go to general loader typename metal::conditional_t< // Check if filter size is small enough SMALL_FILTER, // Go to small filter specialization Conv2DInputBlockLoaderSmallFilter< T, BM, BN, BK, tgp_size, tgp_padding_a>, // Else go to large filter generalization Conv2DInputBlockLoaderLargeFilter< T, BM, BN, BK, tgp_size, tgp_padding_a>>>; // Weight loader using loader_b_t = typename metal::conditional_t< // Check for small channel specialization N_CHANNELS != 0 && N_CHANNELS <= 4, // Go to small channel specialization Conv2DWeightBlockLoaderSmallChannels< T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>, // Else go to general loader Conv2DWeightBlockLoader>; using mma_t = BlockMMA< T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, shape_a_cols, shape_b_cols>; threadgroup T As[tgp_mem_size_a]; threadgroup T Bs[tgp_mem_size_b]; const int tid_y = ((tid.y) << gemm_params->swizzle_log) + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); const int tid_x = (tid.x) >> gemm_params->swizzle_log; if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { return; } const int c_row = tid_y * BM; const int c_col = tid_x * BN; const int K = gemm_params->K; const int N = gemm_params->N; const int C_per_group = params->C / params->groups; // Groups A += tid.z * C_per_group; B += tid.z * N * K; C += tid.z * N; B += c_col * K; C += c_row * (N * params->groups) + c_col; const int2 offsets_a(0, c_row); const int2 offsets_b(0, c_col); // Prepare threadgroup loading operations loader_a_t loader_a( A, As, offsets_a, params, 
gemm_params, simd_gid, simd_lid); loader_b_t loader_b( B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); // Prepare threadgroup mma operation mma_t mma_op(simd_gid, simd_lid); int gemm_k_iterations = gemm_params->gemm_k_iterations; for (int k = 0; k < gemm_k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } threadgroup_barrier(mem_flags::mem_none); // Store results to device memory short tgp_bm = min(BM, gemm_params->M - c_row); short tgp_bn = min(BN, gemm_params->N - c_col); const int ldc = N * params->groups; mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); } ================================================ FILE: mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal ================================================ // Copyright © 2024 Apple Inc. 
#include // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h" #define instantiate_implicit_conv_2d( \ name, \ itype, \ bm, \ bn, \ bk, \ wm, \ wn, \ channel_name, \ n_channels, \ filter_name, \ small_filter) \ template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn \ "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name \ "_filter_" #filter_name)]] [[kernel]] void \ implicit_gemm_conv_2d( \ const device itype* A [[buffer(0)]], \ const device itype* B [[buffer(1)]], \ device itype* C [[buffer(2)]], \ const constant MLXConvParams<2>* params [[buffer(3)]], \ const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) #define instantiate_implicit_2d_blocks(name, itype) \ instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 64, 32, 
16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) instantiate_implicit_2d_blocks(float32, float); instantiate_implicit_2d_blocks(float16, half); instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on ================================================ FILE: mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h ================================================ // Copyright © 2024 Apple Inc. #include using namespace metal; template < typename T, int BM, int BN, int BK, int WM, int WN, bool SMALL_FILTER = false> [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_3d( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], device T* C [[buffer(2)]], const constant MLXConvParams<3>* params [[buffer(3)]], const constant ImplicitGemmConv3DParams* gemm_params [[buffer(4)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { using namespace mlx::steel; (void)lid; constexpr bool transpose_a = false; constexpr bool transpose_b = true; constexpr short tgp_padding_a = 16 / sizeof(T); constexpr short tgp_padding_b = 16 / sizeof(T); constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; constexpr short shape_a_rows = (transpose_a ? BK : BM); constexpr short shape_b_rows = (transpose_b ? 
BN : BK); constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; constexpr short tgp_size = WM * WN * 32; // Input loader using loader_a_t = typename metal::conditional_t< // If the filter is small we can precompute masks for bounds checking SMALL_FILTER, Conv3DInputBlockLoaderSmallFilter, Conv3DInputBlockLoaderLargeFilter< T, BM, BN, BK, tgp_size, tgp_padding_a>>; // Weight loader using loader_b_t = Conv3DWeightBlockLoader; using mma_t = BlockMMA< T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, shape_a_cols, shape_b_cols>; threadgroup T As[tgp_mem_size_a]; threadgroup T Bs[tgp_mem_size_b]; const int tid_y = ((tid.y) << gemm_params->swizzle_log) + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); const int tid_x = (tid.x) >> gemm_params->swizzle_log; if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { return; } const int c_row = tid_y * BM; const int c_col = tid_x * BN; const int K = gemm_params->K; const int N = gemm_params->N; const int C_per_group = params->C / params->groups; // Groups A += tid.z * C_per_group; B += tid.z * N * K; C += tid.z * N; B += c_col * K; C += c_row * (N * params->groups) + c_col; const int2 offsets_a(0, c_row); const int2 offsets_b(0, c_col); // Prepare threadgroup loading operations loader_a_t loader_a( A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); loader_b_t loader_b( B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); // Prepare threadgroup mma operation mma_t mma_op(simd_gid, simd_lid); int gemm_k_iterations = gemm_params->gemm_k_iterations; for (int k = 0; k < gemm_k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } 
threadgroup_barrier(mem_flags::mem_none); // Store results to device memory short tgp_bm = min(BM, gemm_params->M - c_row); short tgp_bn = min(BN, gemm_params->N - c_col); const int ldc = N * params->groups; mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); } ================================================ FILE: mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.metal ================================================ // Copyright © 2024 Apple Inc. #include // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h" #define instantiate_implicit_conv_3d( \ name, \ itype, \ bm, \ bn, \ bk, \ wm, \ wn, \ fn, \ f) \ instantiate_kernel( \ "implicit_gemm_conv_3d_" #name "_bm" #bm "_bn" #bn \ "_bk" #bk "_wm" #wm "_wn" #wn "_filter_" #fn, \ implicit_gemm_conv_3d, \ itype, \ bm, \ bn, \ bk, \ wm, \ wn, \ f) #define instantiate_implicit_conv_3d_filter(name, itype, bm, bn, bk, wm, wn) \ instantiate_implicit_conv_3d(name, itype, bm, bn, bk, wm, wn, s, true) \ instantiate_implicit_conv_3d(name, itype, bm, bn, bk, wm, wn, l, false) #define instantiate_implicit_3d_blocks(name, itype) \ instantiate_implicit_conv_3d_filter(name, itype, 32, 8, 16, 4, 1) \ instantiate_implicit_conv_3d_filter(name, itype, 64, 8, 16, 4, 1) \ instantiate_implicit_conv_3d_filter(name, itype, 32, 32, 16, 2, 2) \ instantiate_implicit_conv_3d_filter(name, itype, 32, 64, 16, 2, 2) \ instantiate_implicit_conv_3d_filter(name, itype, 64, 32, 16, 2, 2) \ instantiate_implicit_conv_3d_filter(name, itype, 64, 64, 16, 2, 2) instantiate_implicit_3d_blocks(float32, float); instantiate_implicit_3d_blocks(float16, half); instantiate_implicit_3d_blocks(bfloat16, bfloat16_t); // clang-format on ================================================ FILE: 
mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h ================================================ // Copyright © 2024 Apple Inc. #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h" constant bool align_C [[function_constant(200)]]; template < typename T, int BM, int BN, int BK, int WM, int WN, typename AccumType = float, typename Epilogue = TransformNone> [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], device T* C [[buffer(2)]], const constant MLXConvParams<2>* params [[buffer(3)]], const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr bool transpose_a = false; constexpr bool transpose_b = true; constexpr short tgp_padding_a = 16 / sizeof(T); constexpr short tgp_padding_b = 16 / sizeof(T); constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; constexpr short shape_a_rows = (transpose_a ? BK : BM); constexpr short shape_b_rows = (transpose_b ? 
BN : BK); constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; constexpr short tgp_size = WM * WN * 32; // Input loader using loader_a_t = Conv2DInputBlockLoaderGeneral; // Weight loader using loader_b_t = Conv2DWeightBlockLoaderGeneral; using mma_t = BlockMMA< T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, shape_a_cols, shape_b_cols>; threadgroup T As[tgp_mem_size_a]; threadgroup T Bs[tgp_mem_size_b]; const int tid_y = ((tid.y) << gemm_params->swizzle_log) + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); const int tid_x = (tid.x) >> gemm_params->swizzle_log; if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { return; } const int tid_z = tid.z; const int base_oh = tid_z / jump_params->f_out_jump_w; const int base_ow = tid_z % jump_params->f_out_jump_w; const int base_wh = base_h[base_oh].weight_base; const int base_ww = base_w[base_ow].weight_base; const int base_wh_size = base_h[base_oh].weight_size; const int base_ww_size = base_w[base_ow].weight_size; const int c_row = tid_y * BM; const int c_col = tid_x * BN; const int K = gemm_params->K; B += c_col * K; const int4 offsets_a(0, c_row, base_oh, base_ow); const int2 offsets_b(0, c_col); // Prepare threadgroup loading operations loader_a_t loader_a( A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid); loader_b_t loader_b( B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid); // Prepare threadgroup mma operation mma_t mma_op(simd_gid, simd_lid); if (align_C) { int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; for (int k = 0; k < gemm_k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration 
loader_a.next(); loader_b.next(); } } else { for (int k = 1; k < gemm_params->gemm_k_iterations; k++) { for (int j = 0; j < base_wh_size * base_ww_size; j++) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } } const short remaining_k = params->C % BK; for (int j = 0; j < base_wh_size * base_ww_size; j++) { // Load elements into threadgroup threadgroup_barrier(mem_flags::mem_threadgroup); loader_a.load_safe(remaining_k); loader_b.load_safe(remaining_k); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } } threadgroup_barrier(mem_flags::mem_none); // Store results to device memory { // Adjust for simdgroup and thread location int offset_m = c_row + mma_op.sm; int offset_n = c_col + mma_op.sn; C += offset_n; if (offset_n >= gemm_params->N) return; short diff = gemm_params->N - offset_n; STEEL_PRAGMA_UNROLL for (int i = 0; i < mma_t::TM; i++) { int cm = offset_m + i * mma_t::TM_stride; int n = cm / jump_params->adj_out_hw; int hw = cm % jump_params->adj_out_hw; int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) { int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2]; STEEL_PRAGMA_UNROLL for (int j = 0; j < mma_t::TN; j++) { // Get accumulated result and associated offset in C thread const auto& accum = mma_op.Ctile.frag_at(i, j); int offset = offset_cm + (j * mma_t::TN_stride); constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag; // Apply epilogue and 
output C STEEL_PRAGMA_UNROLL for (short k = 0; k < kelems; k++) { if ((j * mma_t::TN_stride + k) < diff) { C[offset + k] = Epilogue::apply(accum[k]); } } } } } } } ================================================ FILE: mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal ================================================ // Copyright © 2024 Apple Inc. #include // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h" using namespace metal; using namespace mlx::steel; #define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \ template \ [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \ "_bk" #bk "_wm" #wm "_wn" #wn)]] [[kernel]] void \ implicit_gemm_conv_2d_general( \ const device itype* A [[buffer(0)]], \ const device itype* B [[buffer(1)]], \ device itype* C [[buffer(2)]], \ const constant MLXConvParams<2>* params [[buffer(3)]], \ const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \ const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \ const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \ const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) #define instantiate_implicit_2d_blocks(name, itype) \ instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \ 
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) instantiate_implicit_2d_blocks(float32, float); instantiate_implicit_2d_blocks(float16, half); instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on ================================================ FILE: mlx/backend/metal/kernels/steel/conv/loader.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h" #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h" ================================================ FILE: mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" /////////////////////////////////////////////////////////////////////////////// // Loading helper /////////////////////////////////////////////////////////////////////////////// namespace mlx { namespace steel { template < typename T, short BM, short BN, short BK, short tgp_size, short tgp_padding = 0> struct Conv2DInputBlockLoaderLargeFilter { // Destination dimensions STEEL_CONST short BROWS = BM; STEEL_CONST short BCOLS = BK; // Read dimensions STEEL_CONST short dst_ld = BCOLS + tgp_padding; STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 
8 : 4; // Thread read shape STEEL_CONST short TCOLS = BCOLS / vec_size; STEEL_CONST short TROWS = tgp_size / TCOLS; // Rows / strided reads within the block STEEL_CONST short n_rows = BROWS / TROWS; // Thread location indices const short thread_idx; const short bi; const short bj; // threadgroup and device memory threadgroup T* dst; const constant MLXConvParams<2>* params; const constant ImplicitGemmConv2DParams* gemm_params; short weight_h; short weight_w; const device T* src[n_rows]; int read_n[n_rows]; int read_ih[n_rows]; int read_iw[n_rows]; /* Constructor */ METAL_FUNC Conv2DInputBlockLoaderLargeFilter( const device T* src_, threadgroup T* dst_, const int2 offsets, const constant MLXConvParams<2>* params_, const constant ImplicitGemmConv2DParams* gemm_params_, uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) : thread_idx(simd_group_id * 32 + simd_lane_id), bi(thread_idx / TCOLS), bj(vec_size * (thread_idx % TCOLS)), dst(dst_ + bi * dst_ld + bj), params(params_), gemm_params(gemm_params_), weight_h(0), weight_w(0) { int out_n_pixels = params->oS[0] * params->oS[1]; STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; ++i) { int offset_nhw = offsets.y + bi + i * TROWS; int n = offset_nhw / out_n_pixels; int hw = offset_nhw % out_n_pixels; int oh = hw / params->oS[1]; int ow = hw % params->oS[1]; int ih = oh * params->str[0] - params->pad[0]; int iw = ow * params->str[1] - params->pad[1]; read_n[i] = n; read_ih[i] = ih; read_iw[i] = iw; // Adjust for flip if (params->flip) { ih += (params->wS[0] - 1) * params->kdil[0]; iw += (params->wS[1] - 1) * params->kdil[1]; } // Read from input if in bounds src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + iw * params->in_strides[2] + bj; } } /* Load from device memory into threadgroup memory - without bound checking */ METAL_FUNC void load_unsafe() const { STEEL_PRAGMA_UNROLL for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { // Find bounds 
int n = read_n[i]; int ih = read_ih[i] + weight_h * params->kdil[0]; int iw = read_iw[i] + weight_w * params->kdil[1]; // Read from input if in bounds if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && (iw >= 0 && iw < params->iS[1])) { STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; ++j) { dst[is * dst_ld + j] = src[i][j]; } } // Zero pad otherwise else { STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; ++j) { dst[is * dst_ld + j] = T(0); } } } } /* Iteration helper */ METAL_FUNC void next() { if (++weight_w < params->wS[1]) { STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; i++) { src[i] += gemm_params->inp_jump_w; } return; } weight_w = 0; if (++weight_h < params->wS[0]) { STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; i++) { src[i] += gemm_params->inp_jump_h; } return; } weight_h = 0; STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; i++) { src[i] += gemm_params->inp_jump_c; } } }; template < typename T, short BM, short BN, short BK, short tgp_size, short tgp_padding = 0> struct Conv2DInputBlockLoaderSmallFilter { // Destination dimensions STEEL_CONST short BROWS = BM; STEEL_CONST short BCOLS = BK; // Read dimensions STEEL_CONST short dst_ld = BCOLS + tgp_padding; STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 
8 : 4; // Thread read shape STEEL_CONST short TCOLS = BCOLS / vec_size; STEEL_CONST short TROWS = tgp_size / TCOLS; // Rows / strided reads within the block STEEL_CONST short n_rows = BROWS / TROWS; using mask_t = short; // Thread location indices const short thread_idx; const short bi; const short bj; // threadgroup and device memory threadgroup T* dst; const constant MLXConvParams<2>* params; const constant ImplicitGemmConv2DParams* gemm_params; short weight_h; short weight_w; const device T* src[n_rows]; mask_t mask_h[n_rows]; mask_t mask_w[n_rows]; /* Constructor */ METAL_FUNC Conv2DInputBlockLoaderSmallFilter( const device T* src_, threadgroup T* dst_, const int2 offsets, const constant MLXConvParams<2>* params_, const constant ImplicitGemmConv2DParams* gemm_params_, uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) : thread_idx(simd_group_id * 32 + simd_lane_id), bi(thread_idx / TCOLS), bj(vec_size * (thread_idx % TCOLS)), dst(dst_ + bi * dst_ld + bj), params(params_), gemm_params(gemm_params_), weight_h(0), weight_w(0) { int out_n_pixels = params->oS[0] * params->oS[1]; int read_n[n_rows]; int read_ih[n_rows]; int read_iw[n_rows]; STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; ++i) { int offset_nhw = offsets.y + bi + i * TROWS; int n = offset_nhw / out_n_pixels; int hw = offset_nhw % out_n_pixels; int oh = hw / params->oS[1]; int ow = hw % params->oS[1]; int ih = oh * params->str[0] - params->pad[0]; int iw = ow * params->str[1] - params->pad[1]; read_n[i] = n; read_ih[i] = ih; read_iw[i] = iw; // Adjust for flip if (params->flip) { ih += (params->wS[0] - 1) * params->kdil[0]; iw += (params->wS[1] - 1) * params->kdil[1]; } // Read from input if in bounds src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + iw * params->in_strides[2] + bj; } STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; ++i) { mask_h[i] = 0; mask_w[i] = 0; } for (short kh = 0; kh < params->wS[0]; kh++) { short 
flip_h = params->flip ? params->wS[0] - kh - 1 : kh; STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; ++i) { int n = read_n[i]; int ih = read_ih[i] + flip_h * params->kdil[0]; bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0]; mask_h[i] |= (in_bounds << kh); } } for (short kw = 0; kw < params->wS[1]; kw++) { short flip_w = params->flip ? params->wS[1] - kw - 1 : kw; STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; ++i) { int iw = read_iw[i] + flip_w * params->kdil[1]; bool in_bounds = iw >= 0 && iw < params->iS[1]; mask_w[i] |= (in_bounds << kw); } } } /* Load from device memory into threadgroup memory - without bound checking */ METAL_FUNC void load_unsafe() const { mask_t h_mask = mask_t(1) << weight_h; mask_t w_mask = mask_t(1) << weight_w; STEEL_PRAGMA_UNROLL for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { // Read from input if in bounds if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) { STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; ++j) { dst[is * dst_ld + j] = src[i][j]; } } // Zero pad otherwise else { STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; ++j) { dst[is * dst_ld + j] = T(0); } } } } /* Iteration helper */ METAL_FUNC void next() { if (++weight_w < params->wS[1]) { STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; i++) { src[i] += gemm_params->inp_jump_w; } return; } weight_w = 0; if (++weight_h < params->wS[0]) { STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; i++) { src[i] += gemm_params->inp_jump_h; } return; } weight_h = 0; STEEL_PRAGMA_UNROLL for (short i = 0; i < n_rows; i++) { src[i] += gemm_params->inp_jump_c; } } }; template < typename T, short BM, short BN, short BK, short tgp_size, short tgp_padding = 0> struct Conv2DWeightBlockLoader { // Destination dimensions STEEL_CONST short BROWS = BN; STEEL_CONST short BCOLS = BK; // Read dimensions STEEL_CONST short dst_ld = BCOLS + tgp_padding; STEEL_CONST short vec_size = (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 
8 : 4); // Thread read shape STEEL_CONST short TCOLS = BCOLS / vec_size; STEEL_CONST short TROWS = tgp_size / TCOLS; // Rows / strided reads within the block STEEL_CONST short n_rows = BROWS / TROWS; // Leading dimension for src const int src_ld; // Thread location indices const short thread_idx; const short bi; const short bj; // threadgroup and device memory threadgroup T* dst; const device T* src; const constant MLXConvParams<2>* params; int weight_hw; int weight_step; const int read_n; const bool do_read; /* Constructor */ METAL_FUNC Conv2DWeightBlockLoader( const device T* src_, threadgroup T* dst_, const int2 offsets, const constant MLXConvParams<2>* params_, const constant ImplicitGemmConv2DParams* gemm_params_, uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) : src_ld(params_->wt_strides[0]), thread_idx(simd_group_id * 32 + simd_lane_id), bi(thread_idx / TCOLS), bj(vec_size * (thread_idx % TCOLS)), dst(dst_ + bi * dst_ld + bj), src(src_ + bi * src_ld + bj), params(params_), weight_hw(0), weight_step(params->C / params->groups), read_n(offsets.y + bi), do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} /* Load from device memory into threadgroup memory - without bound checking */ METAL_FUNC void load_unsafe() const { if (BN != 8 || do_read) { STEEL_PRAGMA_UNROLL for (short i = 0; i < BN; i += TROWS) { STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { dst[i * dst_ld + j] = src[i * src_ld + j]; } } } else { for (short i = 0; i < BN; i += TROWS) { if ((read_n + i) < params->O) { STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { dst[i * dst_ld + j] = src[i * src_ld + j]; } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { dst[i * dst_ld + j] = T(0); } } } } } /* Iteration helper */ METAL_FUNC void next() { if (++weight_hw < (params->wS[1] * params->wS[0])) { src += weight_step; return; } weight_hw = 0; src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; } }; 
// Loads one BM x BK tile of the implicit-GEMM input operand for 3D convolution
// from device memory into threadgroup memory.
//
// Thread layout: the tgp_size threads (thread_idx = simd_group_id * 32 +
// simd_lane_id) are arranged as TROWS x TCOLS; each thread copies `vec_size`
// contiguous columns starting at bj for `n_rows` tile rows spaced TROWS apart.
//
// Bounds are checked on every load from the stored base coordinates plus the
// current kernel position; no per-position bitmask is precomputed (contrast
// with Conv3DInputBlockLoaderSmallFilter below).
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv3DInputBlockLoaderLargeFilter {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<3>* params;
  const constant ImplicitGemmConv3DParams* gemm_params;

  // Current kernel position (d, h, w), advanced by next().
  short weight_d;
  short weight_h;
  short weight_w;

  // Kernel dilation per axis, negated when params->flip so that stepping the
  // kernel position walks the receptive field backwards from the far corner.
  short kdil_d;
  short kdil_h;
  short kdil_w;

  const device T* src[n_rows];

  // Per-row batch index and input-space origin (already flip-adjusted).
  int read_n[n_rows];
  int read_id[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv3DInputBlockLoaderLargeFilter(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<3>* params_,
      const constant ImplicitGemmConv3DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_d(0),
        weight_h(0),
        weight_w(0),
        kdil_d(params_->flip ? -params_->kdil[0] : params_->kdil[0]),
        kdil_h(params_->flip ? -params_->kdil[1] : params_->kdil[1]),
        kdil_w(params_->flip ? -params_->kdil[2] : params_->kdil[2]) {
    int out_n_pixels = params->oS[0] * params->oS[1] * params->oS[2];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      // Decompose this row's linear output index into (n, od, oh, ow).
      int offset_ndhw = offsets.y + bi + i * TROWS;
      int n = offset_ndhw / out_n_pixels;
      int dhw = offset_ndhw % out_n_pixels;
      int od = dhw / (params->oS[1] * params->oS[2]);
      int hw = dhw % (params->oS[1] * params->oS[2]);
      int oh = hw / params->oS[2];
      int ow = hw % params->oS[2];

      // Input-space origin of the receptive field for this output position.
      int id = od * params->str[0] - params->pad[0];
      int ih = oh * params->str[1] - params->pad[1];
      int iw = ow * params->str[2] - params->pad[2];

      // With flip, start from the far corner of the receptive field; the
      // negated kdil_* members then walk it backwards in load_unsafe().
      if (params->flip) {
        id += (params->wS[0] - 1) * params->kdil[0];
        ih += (params->wS[1] - 1) * params->kdil[1];
        iw += (params->wS[2] - 1) * params->kdil[2];
      }

      read_n[i] = n;
      read_id[i] = id;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // Device pointer to this row's starting pixel and first column (bj).
      // It is only dereferenced after the bounds check in load_unsafe().
      src[i] = src_ + n * params->in_strides[0] + id * params->in_strides[1] +
          ih * params->in_strides[2] + iw * params->in_strides[3] + bj;
    }
  }

  /* Load from device memory into threadgroup memory - no K (column) bound
   * check; rows whose batch index or input coordinate is out of range for the
   * current kernel position are zero-filled. */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];
      int id = read_id[i] + weight_d * kdil_d;
      int ih = read_ih[i] + weight_h * kdil_h;
      int iw = read_iw[i] + weight_w * kdil_w;

      // Read from input if in bounds
      if ((n < params->N) && (id >= 0 && id < params->iS[0]) &&
          (ih >= 0 && ih < params->iS[1]) && (iw >= 0 && iw < params->iS[2])) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = src[i][j];
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper: advance the kernel position w -> h -> d, and after a
   * full sweep move to the next channel block. Pointer increments come from
   * gemm_params->inp_jump_*; NOTE(review): with flip these presumably encode
   * the reversed walk direction — confirm on the host side. */
  METAL_FUNC void next() {
    if (++weight_w < params->wS[2]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_w;
      }

      return;
    }

    weight_w = 0;

    if (++weight_h < params->wS[1]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_h;
      }

      return;
    }

    weight_h = 0;

    if (++weight_d < params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_d;
      }

      return;
    }

    weight_d = 0;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += gemm_params->inp_jump_c;
    }
  }
};

// Same tile load as Conv3DInputBlockLoaderLargeFilter, but per-row in-bounds
// information for every kernel position is precomputed in the constructor as
// one bit per position in mask_d / mask_h / mask_w, so load_unsafe() only
// tests three bits.
// NOTE(review): mask_t is `short` (16 bits), so this presumably requires each
// wS[k] <= 16 — confirm that the dispatcher enforces this.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv3DInputBlockLoaderSmallFilter {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  using mask_t = short;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<3>* params;
  const constant ImplicitGemmConv3DParams* gemm_params;

  // Current kernel position (d, h, w), advanced by next().
  short weight_d;
  short weight_h;
  short weight_w;

  const device T* src[n_rows];

  // Bit k of mask_*[i] is set when row i is in bounds at kernel index k along
  // that axis (mask_d also folds in the batch check n < N).
  mask_t mask_d[n_rows];
  mask_t mask_h[n_rows];
  mask_t mask_w[n_rows];

  /* Constructor */
  METAL_FUNC Conv3DInputBlockLoaderSmallFilter(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<3>* params_,
      const constant ImplicitGemmConv3DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_d(0),
        weight_h(0),
        weight_w(0) {
    int out_n_pixels = params->oS[0] * params->oS[1] * params->oS[2];

    // Un-flipped receptive-field origins; only needed to build the masks.
    int read_n[n_rows];
    int read_id[n_rows];
    int read_ih[n_rows];
    int read_iw[n_rows];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      // Decompose this row's linear output index into (n, od, oh, ow).
      int offset_ndhw = offsets.y + bi + i * TROWS;
      int n = offset_ndhw / out_n_pixels;
      int dhw = offset_ndhw % out_n_pixels;
      int od = dhw / (params->oS[1] * params->oS[2]);
      int hw = dhw % (params->oS[1] * params->oS[2]);
      int oh = hw / params->oS[2];
      int ow = hw % params->oS[2];

      int id = od * params->str[0] - params->pad[0];
      int ih = oh * params->str[1] - params->pad[1];
      int iw = ow * params->str[2] - params->pad[2];

      read_n[i] = n;
      read_id[i] = id;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // Adjust for flip: the pointer starts at the far corner of the field.
      if (params->flip) {
        id += (params->wS[0] - 1) * params->kdil[0];
        ih += (params->wS[1] - 1) * params->kdil[1];
        iw += (params->wS[2] - 1) * params->kdil[2];
      }

      // Device pointer to this row's starting pixel and first column (bj).
      // It is only dereferenced when the mask bits say it is in bounds.
      src[i] = src_ + n * params->in_strides[0] + id * params->in_strides[1] +
          ih * params->in_strides[2] + iw * params->in_strides[3] + bj;
    }

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      mask_d[i] = 0;
      mask_h[i] = 0;
      mask_w[i] = 0;
    }

    // Bit kd corresponds to the kd-th step of next(); with flip that step
    // touches kernel index wS - kd - 1.
    for (short kd = 0; kd < params->wS[0]; kd++) {
      short flip_d = params->flip ? params->wS[0] - kd - 1 : kd;
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; ++i) {
        int n = read_n[i];
        int id = read_id[i] + flip_d * params->kdil[0];

        bool in_bounds = n < params->N && id >= 0 && id < params->iS[0];

        mask_d[i] |= (in_bounds << kd);
      }
    }

    for (short kh = 0; kh < params->wS[1]; kh++) {
      short flip_h = params->flip ? params->wS[1] - kh - 1 : kh;
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; ++i) {
        int ih = read_ih[i] + flip_h * params->kdil[1];

        bool in_bounds = ih >= 0 && ih < params->iS[1];

        mask_h[i] |= (in_bounds << kh);
      }
    }

    for (short kw = 0; kw < params->wS[2]; kw++) {
      short flip_w = params->flip ? params->wS[2] - kw - 1 : kw;
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; ++i) {
        int iw = read_iw[i] + flip_w * params->kdil[2];

        bool in_bounds = iw >= 0 && iw < params->iS[2];

        mask_w[i] |= (in_bounds << kw);
      }
    }
  }

  /* Load from device memory into threadgroup memory - no K (column) bound
   * check; rows masked out at the current kernel position are zero-filled. */
  METAL_FUNC void load_unsafe() const {
    mask_t d_mask = mask_t(1) << weight_d;
    mask_t h_mask = mask_t(1) << weight_h;
    mask_t w_mask = mask_t(1) << weight_w;

    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Read from input if in bounds
      if ((mask_d[i] & d_mask) && (mask_h[i] & h_mask) &&
          (mask_w[i] & w_mask)) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = src[i][j];
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper: advance the kernel position w -> h -> d, and after a
   * full sweep move to the next channel block via gemm_params->inp_jump_*. */
  METAL_FUNC void next() {
    if (++weight_w < params->wS[2]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_w;
      }

      return;
    }

    weight_w = 0;

    if (++weight_h < params->wS[1]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_h;
      }

      return;
    }

    weight_h = 0;

    if (++weight_d < params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_d;
      }

      return;
    }

    weight_d = 0;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += gemm_params->inp_jump_c;
    }
  }
};

// Loads one BN x BK tile of the 3D convolution weight operand into
// threadgroup memory. Tile rows correspond to output channels (checked against
// params->O); successive rows are src_ld = wt_strides[0] apart.
// next() steps through all wS[0]*wS[1]*wS[2] kernel positions in increments of
// weight_step = C / groups, then rewinds and advances BK along the channels.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv3DWeightBlockLoader {
  // Destination dimensions
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size =
      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<3>* params;

  // Linear index of the current kernel position and the pointer step between
  // consecutive positions.
  int weight_dhw;
  int weight_step;

  // First output-channel row this thread reads, and whether all of its rows
  // fit within gemm_params->N.
  const int read_n;
  const bool do_read;

  /* Constructor */
  METAL_FUNC Conv3DWeightBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<3>* params_,
      const constant ImplicitGemmConv3DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        weight_dhw(0),
        weight_step(params->C / params->groups),
        read_n(offsets.y + bi),
        do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}

  /* Load from device memory into threadgroup memory - without bound checking.
   * Only the BN == 8 configuration checks rows against params->O.
   * NOTE(review): for BN != 8 no row check is done; callers presumably
   * guarantee the tile lies within O — confirm at dispatch. */
  METAL_FUNC void load_unsafe() const {
    if (BN != 8 || do_read) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BN; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = src[i * src_ld + j];
        }
      }
    } else {
      for (short i = 0; i < BN; i += TROWS) {
        if ((read_n + i) < params->O) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = src[i * src_ld + j];
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper: next kernel position, or after the last one rewind to
   * the first position and advance BK along the channel dimension. */
  METAL_FUNC void next() {
    if (++weight_dhw < (params->wS[0] * params->wS[1] * params->wS[2])) {
      src += weight_step;
      return;
    }

    weight_dhw = 0;

    src += BK -
        (params->wS[0] * params->wS[1] * params->wS[2] - 1) * weight_step;
  }
};

} // namespace steel
} // namespace mlx

================================================
FILE: mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/conv/params.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

// Maps a (small) input channel count to the per-thread vector width used by
// the small-channel loaders, and how many padding lanes that width leaves.
// Generic rule: up to 4 channels -> width 4, otherwise width 8; 1 and 2
// channels are specialized to exact widths below.
template <short n_channels_>
struct ChannelHelper {
  STEEL_CONST short n_channels = n_channels_;
  STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8;
  STEEL_CONST short excess = vec_size - n_channels_;
};

template <>
struct ChannelHelper<1> {
  STEEL_CONST short n_channels = 1;
  STEEL_CONST short vec_size = 1;
  STEEL_CONST short excess = 0;
};

template <>
struct ChannelHelper<2> {
  STEEL_CONST short n_channels = 2;
  STEEL_CONST short vec_size = 2;
  STEEL_CONST short excess = 0;
};

template <>
struct ChannelHelper<3> {
  STEEL_CONST short n_channels = 3;
  STEEL_CONST short vec_size = 4;
  STEEL_CONST short excess = 1;
};

template <>
struct ChannelHelper<4> {
  STEEL_CONST short n_channels = 4;
  STEEL_CONST short vec_size = 4;
  STEEL_CONST short excess = 0;
};

// Input-tile loader for 2D convolution with few input channels. Each thread
// column (thread_idx % TCOLS) handles a different kernel position
// (weight_hw = kh * wS[1] + kw), copying n_channels values per pixel and
// zero-padding up to vec_size. next() advances every column by TCOLS kernel
// positions; positions past wS[0] * wS[1] are zero-filled.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short n_channels,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallChannels {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant ImplicitGemmConv2DParams* gemm_params;

  // Linear kernel position (h * wS[1] + w) this thread currently loads.
  int weight_hw;

  const device T* src[n_rows];

  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderSmallChannels(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_hw(thread_idx % TCOLS) {
    int out_n_pixels = params->oS[0] * params->oS[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      // Decompose this row's linear output index into (n, oh, ow).
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      int oh = hw / params->oS[1];
      int ow = hw % params->oS[1];

      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      // Pointer to the receptive-field origin; the kernel offset is added in
      // load_unsafe() only after the bounds check.
      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
          iw * params->in_strides[2];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;
    }
  }

  /* Load from device memory into threadgroup memory - no K bound check
   * beyond the kernel-position limit; out-of-range rows are zero-filled. */
  METAL_FUNC void load_unsafe() const {
    // This thread's column is past the last kernel position: all zeros.
    if (weight_hw >= params->wS[1] * params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    int wh = (weight_hw / params->wS[1]);
    int ww = (weight_hw % params->wS[1]);

    int flip_h = params->flip ? params->wS[0] - wh - 1 : wh;
    int flip_w = params->flip ? params->wS[1] - ww - 1 : ww;

    // Input-space offset of this kernel tap.
    int weight_h = flip_h * params->kdil[0];
    int weight_w = flip_w * params->kdil[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];
      int ih = read_ih[i] + weight_h;
      int iw = read_iw[i] + weight_w;

      // Read from input if in bounds
      if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
          (iw >= 0 && iw < params->iS[1])) {
        const device T* curr_src = src[i] + weight_h * params->in_strides[1] +
            weight_w * params->in_strides[2];

        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < n_channels; ++j) {
          dst[is * dst_ld + j] = curr_src[j];
        }

        // Pad the vector lanes beyond the real channel count.
        STEEL_PRAGMA_UNROLL
        for (short j = n_channels; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper: each column moves TCOLS kernel positions forward. */
  METAL_FUNC void next() {
    weight_hw += TCOLS;
  }
};

// Weight-tile loader matching Conv2DInputBlockLoaderSmallChannels: tile rows
// are output channels, each thread column loads n_channels weights of one
// kernel position (stride C / groups between positions) and zero-pads to
// vec_size.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short n_channels,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoaderSmallChannels {
  // Destination dimensions
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;

  // Linear kernel position (h * wS[1] + w) this thread currently loads.
  int weight_hw;

  const int read_n;
  const bool do_read;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld),
        params(params_),
        weight_hw(thread_idx % TCOLS),
        read_n(offsets.y + bi),
        do_read(read_n + BN <= gemm_params_->N) {}

  /* Load from device memory into threadgroup memory - rows past params->O
   * and kernel positions past wS[0] * wS[1] are zero-filled. */
  METAL_FUNC void load_unsafe() const {
    // Threads that map outside the tile have nothing to load.
    if (bi >= BROWS || bj >= BCOLS)
      return;

    if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    const device T* curr_src = src + weight_hw * (params->C / params->groups);

    if (BN != 8 || do_read) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < n_channels; j++) {
          dst[i * dst_ld + j] = curr_src[i * src_ld + j];
        }

        STEEL_PRAGMA_UNROLL
        for (short j = n_channels; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
    } else {
      for (short i = 0; i < BROWS; i += TROWS) {
        if (((read_n + i) < params->O)) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < n_channels; j++) {
            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
          }

          STEEL_PRAGMA_UNROLL
          for (short j = n_channels; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper: each column moves TCOLS kernel positions forward. */
  METAL_FUNC void next() {
    weight_hw += TCOLS;
  }
};

} // namespace steel
} // namespace mlx

================================================
FILE: mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
================================================
// Copyright © 2024 Apple Inc.
#pragma once

#include "mlx/backend/metal/kernels/steel/defines.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

// Loads one BM x BK tile of the implicit-GEMM input operand for the general
// 2D convolution path into threadgroup memory.
//
// Each of the tgp_size threads copies `vec_size` contiguous columns starting
// at bj for `n_rows` tile rows spaced TROWS apart. A tile row's output
// position is (hw / adj_out_w) * f_out_jump_h + offsets.z and
// (hw % adj_out_w) * f_out_jump_w + offsets.w. Kernel positions start at
// (base_wh, base_ww) and step by f_wgt_jump_{h,w}. Input dilation is applied
// by dividing the dilated coordinate by idil.
// NOTE(review): that division truncates, so this presumably relies on the
// host choosing base/jump values that keep the dilated coordinate a multiple
// of idil — confirm at dispatch.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderGeneral {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant Conv2DGeneralJumpParams* jump_params;

  // Starting kernel position; next() wraps back to these values.
  const short base_wh;
  const short base_ww;

  // Current kernel position.
  short weight_h;
  short weight_w;

  const device T* src[n_rows];

  // Per-row batch index and (undilated-input-space) receptive-field origin.
  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderGeneral(
      const device T* src_,
      threadgroup T* dst_,
      const int4 offsets,
      const constant MLXConvParams<2>* params_,
      const constant Conv2DGeneralJumpParams* jump_params_,
      const short base_wh_,
      const short base_ww_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        jump_params(jump_params_),
        base_wh(base_wh_),
        base_ww(base_ww_),
        weight_h(base_wh_),
        weight_w(base_ww_) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      // Map this row's linear index in the adjusted output grid to a batch
      // index and a real output coordinate (strided by f_out_jump_*).
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / jump_params->adj_out_hw;
      int hw = offset_nhw % jump_params->adj_out_hw;
      int oh =
          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z;
      int ow =
          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w;

      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // Base pointer for batch n at this thread's first column (bj); the
      // spatial offset is computed per load, after the bounds check.
      src[i] = src_ + n * params->in_strides[0] + bj;
    }
  }

  /* Load from device memory into threadgroup memory - no K (column) bound
   * check; rows out of range at the current kernel position are zero-filled.
   */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];

      int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
      int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;

      // Coordinates in the input-dilated space, then mapped back to real
      // input pixels by dividing by idil.
      int ih_dil = read_ih[i] + h_flip * params->kdil[0];
      int iw_dil = read_iw[i] + w_flip * params->kdil[1];

      int ih = ih_dil / params->idil[0];
      int iw = iw_dil / params->idil[1];

      // Only used when the bounds check below passes.
      size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];

      // Read from input if in bounds
      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
          (iw_dil >= 0 && iw < params->iS[1])) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = (src[i])[offset + j];
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Same as load_unsafe(), but columns with bj + j >= remaining_k (past the
   * end of the K dimension) are zero-filled instead of read. */
  METAL_FUNC void load_safe(const short remaining_k) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];

      int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
      int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;

      int ih_dil = read_ih[i] + h_flip * params->kdil[0];
      int iw_dil = read_iw[i] + w_flip * params->kdil[1];

      int ih = ih_dil / params->idil[0];
      int iw = iw_dil / params->idil[1];

      size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];

      // Read from input if in bounds
      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
          (iw_dil >= 0 && iw < params->iS[1])) {
        // Fast path: the whole vector lies inside the remaining K range.
        if (bj + vec_size <= remaining_k) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; ++j) {
            dst[is * dst_ld + j] = (src[i])[offset + j];
          }
        } else {
          for (short j = 0; j < vec_size; ++j) {
            if (bj + j < remaining_k) {
              dst[is * dst_ld + j] = (src[i])[offset + j];
            } else {
              dst[is * dst_ld + j] = T(0);
            }
          }
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper: step the kernel w by f_wgt_jump_w, wrapping to base_ww
   * and stepping h by f_wgt_jump_h; when h also wraps, move every row pointer
   * BK elements forward (next channel block). */
  METAL_FUNC void next() {
    weight_w += jump_params->f_wgt_jump_w;
    if (weight_w < params->wS[1]) {
      return;
    }

    weight_w = base_ww;

    weight_h += jump_params->f_wgt_jump_h;
    if (weight_h < params->wS[0]) {
      return;
    }

    weight_h = base_wh;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += BK;
    }
  }
};

// Loads one BN x BK tile of the 2D convolution weight operand for the general
// path. Tile rows are output channels (checked against params->O); the kernel
// position (weight_h, weight_w) follows the same base/jump schedule as
// Conv2DInputBlockLoaderGeneral, and is applied through wt_strides[1..2].
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoaderGeneral {
  // Destination dimensions
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size =
      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;
  const constant Conv2DGeneralJumpParams* jump_params;

  // Starting kernel position; next() wraps back to these values.
  const short base_wh;
  const short base_ww;

  // Current kernel position.
  short weight_h;
  short weight_w;

  // First output-channel row this thread reads.
  const int start_row;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoaderGeneral(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant Conv2DGeneralJumpParams* jump_params_,
      const short base_wh_,
      const short base_ww_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        jump_params(jump_params_),
        base_wh(base_wh_),
        base_ww(base_ww_),
        weight_h(base_wh_),
        weight_w(base_ww_),
        start_row(offsets.y + bi) {}

  /* Load from device memory into threadgroup memory - no K (column) bound
   * check; rows at or past params->O are zero-filled. */
  METAL_FUNC void load_unsafe() const {
    const device T* curr_src = src + weight_h * params->wt_strides[1] +
        weight_w * params->wt_strides[2];

    // Fast path: every row of the tile is a valid output channel.
    if ((start_row + BN <= params->O)) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BN; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = curr_src[i * src_ld + j];
        }
      }
    } else {
      for (short i = 0; i < BN; i += TROWS) {
        if ((start_row + i) < params->O) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Same as load_unsafe(), but columns with bj + j >= remaining_k (past the
   * end of the K dimension) are zero-filled instead of read. */
  METAL_FUNC void load_safe(const short remaining_k) const {
    const device T* curr_src = src + weight_h * params->wt_strides[1] +
        weight_w * params->wt_strides[2];

    if ((start_row + BN <= params->O)) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BN; i += TROWS) {
        if (bj + vec_size <= remaining_k) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
          }
        } else {
          for (short j = 0; j < vec_size; j++) {
            if (bj + j < remaining_k) {
              dst[i * dst_ld + j] = curr_src[i * src_ld + j];
            } else {
              dst[i * dst_ld + j] = T(0);
            }
          }
        }
      }
    } else {
      for (short i = 0; i < BN; i += TROWS) {
        if ((start_row + i) < params->O) {
          if (bj + vec_size <= remaining_k) {
            STEEL_PRAGMA_UNROLL
            for (short j = 0; j < vec_size; j++) {
              dst[i * dst_ld + j] = curr_src[i * src_ld + j];
            }
          } else {
            for (short j = 0; j < vec_size; j++) {
              if (bj + j < remaining_k) {
                dst[i * dst_ld + j] = curr_src[i * src_ld + j];
              } else {
                dst[i * dst_ld + j] = T(0);
              }
            }
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper: same kernel-position schedule as the input loader;
   * when both w and h wrap, advance src BK elements (next channel block). */
  METAL_FUNC void next() {
    weight_w += jump_params->f_wgt_jump_w;
    if (weight_w < params->wS[1]) {
      return;
    }

    weight_w = base_ww;

    weight_h += jump_params->f_wgt_jump_h;
    if (weight_h < params->wS[0]) {
      return;
    }

    weight_h = base_wh;

    src += BK;
  }
};

} // namespace steel
} // namespace mlx

================================================
FILE: mlx/backend/metal/kernels/steel/conv/params.h
================================================
// Copyright © 2024 Apple Inc.
#pragma once

/**
 * Convolution parameters for an NDIM-dimensional convolution (for example
 * MLXConvParams<2> for 2-D convolutions).
 *
 * The Metal conv kernels read this struct through `constant` pointers, so the
 * member order and types define the layout they see. The host presumably
 * fills the same struct, so keep the layout stable.
 *
 * The spatial arrays (iS, wS, oS, str, pad, kdil, idil) have NDIM entries.
 * The stride arrays have NDIM + 2 entries; with_padded_channels treats the
 * last of them as the (innermost) channel axis.
 */
template <int NDIM>
struct MLXConvParams {
  int N; // Batch size
  int C; // In channels
  int O; // Out channels
  int iS[NDIM]; // Input spatial dim
  int wS[NDIM]; // Weight spatial dim
  int oS[NDIM]; // Output spatial dim
  int str[NDIM]; // Kernel strides
  int pad[NDIM]; // Input padding
  int kdil[NDIM]; // Kernel dilation
  int idil[NDIM]; // Input dilation
  int64_t in_strides[NDIM + 2]; // In strides
  int64_t wt_strides[NDIM + 2]; // Wt strides
  int64_t out_strides[NDIM + 2]; // Out strides
  int groups; // Input channel groups
  bool flip; // Flip flag (not interpreted in this header)

  /**
   * Returns a copy of `other` widened by `pad_in` input channels and
   * `pad_out` output channels.
   *
   * Every outer in/wt stride is rescaled by (C + pad_in) / C and every outer
   * out stride by (O + pad_out) / O. The innermost (channel) stride of each
   * is reset to 1.
   *
   * The integer division makes this exact only when the channel axis is
   * innermost and every outer stride is a multiple of the channel count.
   *
   * @param other   Parameters to copy; taken by value and not modified.
   * @param pad_out Number of channels added to O.
   * @param pad_in  Number of channels added to C.
   * @return The padded parameters.
   */
  static MLXConvParams
  with_padded_channels(MLXConvParams other, int pad_out, int pad_in) {
    MLXConvParams params = other;

    // Update strides: rescale each outer stride to the padded channel count.
    for (int i = 0; i < NDIM + 1; i++) {
      params.in_strides[i] =
          (params.in_strides[i] / params.C) * (params.C + pad_in);
      params.wt_strides[i] =
          (params.wt_strides[i] / params.C) * (params.C + pad_in);
      params.out_strides[i] =
          (params.out_strides[i] / params.O) * (params.O + pad_out);
    }

    // The channel axis stays innermost and contiguous.
    params.in_strides[NDIM + 1] = 1;
    params.wt_strides[NDIM + 1] = 1;
    params.out_strides[NDIM + 1] = 1;

    // Update channels
    params.C += pad_in;
    params.O += pad_out;

    return params;
  }
};

namespace mlx {
namespace steel {

// Parameter blocks read by the steel conv kernels. All members are plain
// ints (const-qualified where shown), so these are filled by aggregate
// initialization.

// Parameters of the implicit-GEMM 2-D convolution kernels.
struct ImplicitGemmConv2DParams {
  const int M;
  const int N;
  const int K;

  const int gemm_k_iterations;

  const int inp_jump_w;
  const int inp_jump_h;
  const int inp_jump_c;

  const int tiles_n;
  const int tiles_m;
  const int swizzle_log;
};

// Parameters of the implicit-GEMM 3-D convolution kernels
// (adds the depth jump inp_jump_d).
struct ImplicitGemmConv3DParams {
  const int M;
  const int N;
  const int K;

  const int gemm_k_iterations;

  const int inp_jump_w;
  const int inp_jump_h;
  const int inp_jump_d;
  const int inp_jump_c;

  const int tiles_n;
  const int tiles_m;
  const int swizzle_log;
};

// Jump/adjustment parameters for the general 2-D conv path. f_wgt_jump_h and
// f_wgt_jump_w are the per-step weight index increments used by
// Conv2DWeightBlockLoaderGeneral::next().
struct Conv2DGeneralJumpParams {
  const int f_wgt_jump_h;
  const int f_wgt_jump_w;

  const int f_out_jump_h;
  const int f_out_jump_w;

  const int adj_out_h;
  const int adj_out_w;
  const int adj_out_hw;
  const int adj_implicit_m;
};

// Base offset and size of a weight slice for the general 2-D conv path.
struct Conv2DGeneralBaseInfo {
  int weight_base;
  int weight_size;
};

} // namespace steel
} // namespace mlx
================================================
FILE: mlx/backend/metal/kernels/steel/defines.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

// Qualifiers for compile-time constants in the steel kernels: a static
// constexpr value in Metal's `constant` address space.
#define STEEL_CONST static constant constexpr const
// Loop hint: ask clang to fully unroll the loop that follows.
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
// Loop hint: ask clang not to unroll the loop that follows.
#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)")

================================================
FILE: mlx/backend/metal/kernels/steel/gemm/gemm.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

// Empty tag type passed as the last argument of GEMMKernel::gemm_loop. Its
// template arguments presumably carry the alignment flags that gemm_loop
// reads (M_aligned, N_aligned, K_aligned_).
// NOTE(review): in this copy the template parameter lists of LoopAlignment
// and of gemm_loop are empty, although gemm_loop uses M_aligned / N_aligned /
// K_aligned_. Restore them from upstream before building.
template struct LoopAlignment {};

/**
 * Threadgroup-level GEMM building block. `run` computes one BM x BN tile of
 * D from A and B, stepping over K in chunks of BK.
 *
 * T            element type of A, B and the threadgroup tiles
 * U            element type of D
 * BM, BN, BK   tile sizes
 * WM, WN       simdgroup grid; a threadgroup has WM * WN * 32 threads
 * transpose_a, transpose_b  select the transposed layout of A / B
 * MN_aligned   when true, `run` takes the path without M/N bounds checks
 * K_aligned    when true, `run` skips the partial final K step
 */
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper::accum_type,
    typename Epilogue = TransformNone>
struct GEMMKernel {
  // Row padding of the threadgroup tiles: 16 bytes' worth of elements.
  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);

  // Element counts of the padded A and B tiles in threadgroup memory.
  STEEL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  STEEL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);

  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  // Threads per threadgroup (32 lanes per simdgroup).
  STEEL_CONST short tgp_size = WM * WN * 32;

  // Cooperative loaders. Each copies one A / B tile from device memory into
  // the padded threadgroup tile.
  using loader_a_t = BlockLoader<
      T,
      transpose_a ? BK : BM,
      transpose_a ? BM : BK,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      !transpose_a,
      tgp_size>;

  using loader_b_t = BlockLoader<
      T,
      transpose_b ? BN : BK,
      transpose_b ? BK : BN,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      transpose_b,
      tgp_size>;

  // Multiply-accumulate operator over the threadgroup tiles.
  using mma_t = BlockMMA<
      T,
      U,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      AccumType,
      Epilogue>;

  /* K-loop helper.
   *
   * Runs gemm_k_iterations full BK-wide steps. Each step loads the A/B
   * tiles into As/Bs, synchronizes, accumulates into mma_op and advances
   * the loaders. M_aligned / N_aligned choose unchecked vs bounds-checked
   * tile loads, with tgp_bm / tgp_bn as the valid extents. When K_aligned_
   * is false, a final partial step of width lbk is accumulated with
   * bounds-checked loads. */
  template
  static METAL_FUNC void gemm_loop(
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      const int gemm_k_iterations,
      thread loader_a_t& loader_a,
      thread loader_b_t& loader_b,
      thread mma_t& mma_op,
      thread const short& tgp_bm,
      thread const short& tgp_bn,
      thread const short& lbk,
      LoopAlignment l = {}) {
    // Appease the compiler
    (void)l;

    // Valid extents of the A and B tiles for bounds-checked loads
    // (apparently x = columns, y = rows, matching the store_result_safe
    // calls in `run`).
    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);

    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      // All threads must be done reading As/Bs from the previous step
      // before they are overwritten.
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(tile_dims_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(tile_dims_B);
      }

      // Make the freshly loaded tiles visible to every simdgroup.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    // Partial final K step of width lbk.
    if (!K_aligned_) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }
  }

  /* Main kernel function
   *
   * Computes the BM x BN tile of D selected by this threadgroup's
   * (swizzled) grid position and writes it to D. */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device U* D [[buffer(2)]],
      const constant GEMMParams* params [[buffer(3)]],
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    // Undo the threadgroup swizzle: the low swizzle_log bits of tid.x
    // become the low bits of the tile row index.
    const int tid_y = ((tid.y) << params->swizzle_log) +
        ((tid.x) & ((1 << params->swizzle_log) - 1));
    const int tid_x = (tid.x) >> params->swizzle_log;

    // Threadgroups that map past the tile grid have no work.
    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
      return;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Find block in A, B, C
    const int c_row = tid_y * BM;
    const int c_col = tid_x * BN;
    // The offsets below are computed in size_t; c_row * lda can exceed the
    // int range for large matrices.
    const size_t c_row_long = size_t(c_row);
    const size_t c_col_long = size_t(c_col);

    A += transpose_a ? c_row_long : c_row_long * params->lda;
    B += transpose_b ? c_col_long * params->ldb : c_col_long;
    D += c_row_long * params->ldd + c_col_long;

    // Prepare threadgroup loading operations
    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    // Number of full BK-wide steps along K.
    int gemm_k_iterations = params->gemm_k_iterations_aligned;

    ///////////////////////////////////////////////////////////////////////////////
    // MNK aligned loop
    if (MN_aligned) {
      for (int k = 0; k < gemm_k_iterations; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      // NOTE(review): only an execution barrier (mem_none) separates the
      // last mma's reads of As/Bs from the tail's writes below. gemm_loop
      // uses mem_threadgroup at the equivalent point; confirm this is safe.
      threadgroup_barrier(mem_flags::mem_none);

      // Loop tail: the remaining K - gemm_k_iterations_aligned * BK elements,
      // loaded with bounds checks.
      if (!K_aligned) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }

      // Store results to device memory
      mma_op.store_result(D, params->ldd);
      return;

    }
    ///////////////////////////////////////////////////////////////////////////////
    // MN unaligned loop
    else {
      // This tile may be partial in M and/or N. Clamp its extents and
      // choose the full or bounds-checked store accordingly.
      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);
      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;

      // NOTE(review): the four gemm_loop calls below are identical in this
      // copy. Upstream presumably passes a different explicit alignment
      // template argument list to each.
      if (tgp_bm == BM && tgp_bn == BN) {
        // Full tile.
        gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result(D, params->ldd);
        return;

      } else if (tgp_bn == BN) {
        // Partial in M only.
        gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else if (tgp_bm == BM) {
        // Partial in N only.
        gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else {
        // Partial in both M and N.
        gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;
      }
    }
  }
};

} // namespace steel
} // namespace mlx

================================================
FILE: mlx/backend/metal/kernels/steel/gemm/gemm_nax.h
================================================
// Copyright © 2025 Apple Inc.
#pragma once

#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

namespace mlx::steel {

/**
 * Per-simdgroup GEMM main loop for the NAX kernels. Accumulates this
 * simdgroup's SM x SN output tile and returns it.
 *
 * K is walked in gemm_k_iterations_aligned blocks of BK. Each block is
 * processed as BK / SK sub-steps of width SK. When kAlignedK is false, the
 * remaining K - gemm_k_iterations_aligned * BK elements are then processed
 * with bounds-checked loads.
 *
 * A, B           base pointers of the A / B operands for this tile; advanced
 *                by one BK block along K per outer iteration
 * lda, ldb       leading dimensions of A and B
 * K              full reduction length (only read when kAlignedK is false)
 * sgp_sm, sgp_sn valid rows / cols of the tile; they bound the loads when
 *                kAlignedM / kAlignedN is false
 *
 * NOTE(review): NAXTile appears below without template arguments, although
 * TM/TN/RA/CA/RB/CB are computed for it. metal::bool_constant is also used
 * without its argument. This copy appears to have lost its <...> argument
 * lists; restore them from upstream.
 */
template <
    typename T,
    short SM,
    short SN,
    short SK,
    short BK,
    bool transpose_a,
    bool transpose_b,
    bool kAlignedM,
    bool kAlignedN,
    bool kAlignedK,
    typename AccumType = float>
auto gemm_loop(
    const device T* A,
    const device T* B,
    int lda,
    int ldb,
    int K,
    int gemm_k_iterations_aligned,
    const short sgp_sm,
    const short sgp_sn) {
  // Tile extents in units of 16 (presumably 16x16 fragments).
  constexpr short TM = SM / 16;
  constexpr short TN = SN / 16;
  constexpr short TK = SK / 16;

  // Row / column counts of the A and B sub-tiles, which depend on
  // transposition.
  constexpr int RA = transpose_a ? TK : TM;
  constexpr int CA = transpose_a ? TM : TK;

  constexpr int RB = transpose_b ? TN : TK;
  constexpr int CB = transpose_b ? TK : TN;

  // Accumulator tile returned to the caller.
  NAXTile Dtile;
  Dtile.clear();

  int gemm_k_iterations_ = gemm_k_iterations_aligned;

  STEEL_PRAGMA_NO_UNROLL
  for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) {
    threadgroup_barrier(mem_flags::mem_none);

    // BK / SK sub-steps of width SK inside this K block.
    STEEL_PRAGMA_NO_UNROLL
    for (int kk1 = 0; kk1 < BK; kk1 += SK) {
      NAXTile Atile;
      NAXTile Btile;
      const int k = kk1;

      // Presumably a scheduling barrier for the compiler: this volatile
      // local is declared here and only "used" at the end of the sub-step.
      volatile int compiler_barrier;

      const int A_offset = transpose_a ? k * lda : k;
      const int B_offset = transpose_b ? k : k * ldb;

      if constexpr (kAlignedM) {
        Atile.load(A + A_offset, lda);
      } else {
        // Bound the M extent by this tile's valid rows.
        const short rmax = transpose_a ? SK : sgp_sm;
        const short cmax = transpose_a ? sgp_sm : SK;
        Atile.load_safe(A + A_offset, lda, short2(cmax, rmax));
      }

      if constexpr (kAlignedN) {
        Btile.load(B + B_offset, ldb);
      } else {
        // Bound the N extent by this tile's valid columns.
        const short rmax = transpose_b ? sgp_sn : SK;
        const short cmax = transpose_b ? SK : sgp_sn;
        Btile.load_safe(B + B_offset, ldb, short2(cmax, rmax));
      }

      // Presumably Dtile += Atile * Btile, with the transposition flags
      // passed as bool_constant tags.
      tile_matmad_nax(
          Dtile,
          Atile,
          metal::bool_constant{},
          Btile,
          metal::bool_constant{});

      (void)compiler_barrier;
    }

    // Advance to the next BK block along K.
    A += transpose_a ? (BK * lda) : BK;
    B += transpose_b ? BK : (BK * ldb);
  }

  if constexpr (!kAlignedK) {
    simdgroup_barrier(mem_flags::mem_none);

    // Leftover K elements after the full BK blocks.
    const short rem_bk = K - gemm_k_iterations_ * BK;

    STEEL_PRAGMA_NO_UNROLL
    for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) {
      NAXTile Atile;
      NAXTile Btile;
      const int k = kk1;
      // Number of K elements still valid from this sub-step on
      // (may exceed SK).
      const short psk = max(0, rem_bk - k);
      const short2 Aklims =
          transpose_a ? short2(sgp_sm, psk) : short2(psk, sgp_sm);
      const short2 Bklims =
          transpose_b ? short2(psk, sgp_sn) : short2(sgp_sn, psk);

      const int A_offset = transpose_a ? k * lda : k;
      const int B_offset = transpose_b ? k : k * ldb;

      Atile.load_safe(A + A_offset, lda, Aklims);
      Btile.load_safe(B + B_offset, ldb, Bklims);

      tile_matmad_nax(
          Dtile,
          Atile,
          metal::bool_constant{},
          Btile,
          metal::bool_constant{});
    }
  }

  return Dtile;
}

} // namespace mlx::steel

================================================
FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h
================================================
// Copyright © 2024 Apple Inc.

using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

// Function constants used to specialize the pipeline:
//   has_batch       per-batch offsets come from batch_shape / batch_strides
//                   rather than the fixed batch_stride_* fields
//   use_out_source  C and addmm_params are bound and C is folded in by the
//                   epilogue
//   do_axpby        use the TransformAxpby epilogue instead of TransformAdd
//   align_M/N/K     the corresponding dimension needs no bounds handling
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

// clang-format off
/**
 * Fused GEMM kernel. Each threadgroup computes one BM x BN tile of D for
 * batch tid.z, optionally combined with C through the addmm epilogue.
 */
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
  // Pacifying compiler
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      true,
      true,
      AccumType>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  // Find block
  // (undo the threadgroup swizzle: the low swizzle_log bits of tid.x become
  // the low bits of the tile row index)
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Exit early if out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Adjust for batch
  if (has_batch) {
    // batch_strides holds the A strides, then the B strides, then (when
    // use_out_source) the C strides, batch_ndim entries each.
    const constant auto* A_bstrides = batch_strides;
    const constant auto* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

    if (use_out_source) {
      const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
      C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
    }
  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;

    if (use_out_source) {
      C += addmm_params->batch_stride_c * tid.z;
    }
  }

  D += params->batch_stride_d * tid.z;

  // Prepare threadgroup memory
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  // (valid rows / cols of this tile; smaller than BM / BN only for the last
  // tile row / column when M / N is unaligned)
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  // Prepare iterations
  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  // Do unaligned K iterations first.
  // The partial K block at the end of K is accumulated up front with
  // bounds-checked loads, so the loops below only see full BK blocks
  // (hence leftover_bk = 0 further down).
  if (!align_K) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int k_remain = params->K - k_last;
    const size_t k_jump_a =
        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
    const size_t k_jump_b =
        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

    // Move loader source ahead to end
    loader_a.src += k_jump_a;
    loader_b.src += k_jump_b;

    // Load tile
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do matmul
    mma_op.mma(As, Bs);

    // Reset source back to start
    loader_a.src -= k_jump_a;
    loader_b.src -= k_jump_b;
  }

  // NOTE(review): addmm_params is declared with
  // function_constant(use_out_source) but is dereferenced here
  // unconditionally. This presumably relies on the compiler discarding these
  // unused objects when use_out_source is false; confirm.
  const TransformAdd epilogue_op_add(
      addmm_params->alpha, addmm_params->beta);
  const TransformAxpby epilogue_op_axpby(
      addmm_params->alpha, addmm_params->beta);

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (align_M && align_N) {
    // Do gemm
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Do epilogue
    if (use_out_source) {
      if (do_axpby) {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
      } else {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
      }
    }

    // Store results to device memory
    return mma_op.store_result(D, params->ldd);

  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else {
    // The partial K block was already handled above, so there is no K tail
    // left for gemm_loop.
    const int leftover_bk = 0;

    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      // Full tile.
      // Do gemm
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result(D, params->ldd);

    } else if (align_N || tgp_bn == BN) {
      // Partial in M only.
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else if (align_M || tgp_bm == BM) {
      // Partial in N only.
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else {
      // Partial in both M and N.
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}

================================================
FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal
================================================
// Copyright © 2024 Apple Inc.
// clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h" #define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_kernel( \ "steel_gemm_fused_" #tname "_" #iname "_" #oname \ "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \ gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float) #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 8, 4, 1) instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); instantiate_gemm_shapes_helper(float32, float, float32, float); instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t); // clang-format on ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h ================================================ // Copyright © 2025 Apple Inc. 
using namespace mlx::steel;

// Function constants used to specialize the pipeline:
//   has_batch       per-batch offsets come from batch_shape / batch_strides
//   use_out_source  C and addmm_params are bound and the epilogue runs
//   do_axpby        epilogue computes alpha * D + beta * C instead of D + C
//   align_M/N/K     the corresponding dimension needs no bounds handling
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

// clang-format off
/**
 * Addmm epilogue for one simdgroup's accumulator tile, intended to update
 * Dtile in place.
 *
 * For each fragment it loads the matching fragment of C (bounds-checked by
 * sgp_sm / sgp_sn unless kAlignedM && kAlignedN) and computes
 *   D = alpha * D + beta * C   when do_axpby is set, or
 *   D = D + C                  otherwise.
 * `params` is unused.
 */
template <
    bool kAlignedM,
    bool kAlignedN,
    class NAXTile_t,
    typename T>
void gemm_epilogue(
    thread NAXTile_t& Dtile,
    const device T* C,
    const constant GEMMParams* params,
    const constant GEMMAddMMParams* addmm_params,
    const short sgp_sm,
    const short sgp_sn) { // clang-format on

  (void)params;

  using V = typename NAXTile_t::elem_type;

  constexpr short TM = NAXTile_t::kTileRows;
  constexpr short TN = NAXTile_t::kTileCols;
  constexpr short kElemsPerFrag = NAXTile_t::kElemsPerFrag;

  using CFrag = typename NAXTile_t::NAXFrag_t;
  using cfrag_t = typename CFrag::template dtype_frag_t;

  STEEL_PRAGMA_UNROLL
  for (short mm = 0; mm < TM; mm++) {
    STEEL_PRAGMA_UNROLL
    for (short nn = 0; nn < TN; nn++) {
      // Element offset of fragment (mm, nn) inside the tile.
      const short m = mm * CFrag::kFragRows;
      const short n = nn * CFrag::kFragCols;

      cfrag_t celems;

      if constexpr (kAlignedM && kAlignedN) {
        CFrag::load(celems, C, addmm_params->ldc, addmm_params->fdc, m, n);
      } else {
        CFrag::load_safe(
            celems, C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n);
      }

      // NOTE(review): if frag_at returns a reference, `auto` deduces a copy
      // and the writes below never reach Dtile. Confirm that frag_at returns
      // a pointer or view (nax.h is not visible here).
      auto delems = Dtile.frag_at(mm, nn);

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < kElemsPerFrag; i++) {
        if (do_axpby) {
          delems[i] = addmm_params->alpha * delems[i] +
              addmm_params->beta * static_cast(celems[i]);
        } else {
          delems[i] += static_cast(celems[i]);
        }
      }
    }
  }
}

// clang-format off
/**
 * Fused GEMM kernel (NAX variant). Each threadgroup covers one BM x BN tile
 * of D for batch tid.z. Each of its WM x WN simdgroups computes an SM x SN
 * sub-tile directly from device memory; no threadgroup memory is used.
 */
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on
  // Find block
  // (undo the threadgroup swizzle: the low swizzle_log bits of tid.x become
  // the low bits of the tile row index)
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Exit early if out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Adjust for batch
  if (has_batch) {
    // batch_strides holds the A strides, then the B strides, then (when
    // use_out_source) the C strides, batch_ndim entries each.
    const constant auto* A_bstrides = batch_strides;
    const constant auto* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

    if (use_out_source) {
      const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
      C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
    }
  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;

    if (use_out_source) {
      C += addmm_params->batch_stride_c * tid.z;
    }
  }

  D += params->batch_stride_d * tid.z;

  // No threadgroup memory is allocated in this kernel; this is an
  // execution-only (mem_none) barrier.
  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }

  // Per-simdgroup sub-tile of SM x SN outputs, reduced over K in SK-wide
  // steps.
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;

  constexpr short TM = SM / 16;
  constexpr short TN = SN / 16;

  // Offset of this simdgroup's sub-tile inside the threadgroup tile
  // (simdgroups laid out row-major on a WM x WN grid).
  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  // Valid rows / cols of the sub-tile.
  // NOTE(review): for a sub-tile lying entirely past M or N these are zero
  // or negative. This relies on load_safe / store_safe treating such extents
  // as empty; confirm.
  const int sgp_sm_int =
      align_M ? int(SM) : min(int(SM), params->M - (c_row + tm));
  const short sgp_sm = short(sgp_sm_int);
  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

  const int sgp_sn_int =
      align_N ? int(SN) : min(int(SN), params->N - (c_col + tn));
  const short sgp_sn = short(sgp_sn_int);
  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

  A += transpose_a ? tm : (tm * params->lda);
  B += transpose_b ? (tn * params->ldb) : tn;
  D += tm * params->ldd + tn;

  if (use_out_source) {
    C += tm * addmm_params->ldc + tn * addmm_params->fdc;
  }

  NAXTile Dtile;

  // dispatch_bool turns each runtime flag into a compile-time constant. It is
  // used below as kAlignedX.value template arguments and in if constexpr, so
  // the matching gemm_loop / epilogue / store specialization runs.
  dispatch_bool(align_K, [&](auto kAlignedK) {
    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
      dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
        Dtile = gemm_loop<
            T,
            SM,
            SN,
            SK,
            BK,
            transpose_a,
            transpose_b,
            kAlignedM.value,
            kAlignedN.value,
            kAlignedK.value,
            AccumType>(
            A,
            B,
            params->lda,
            params->ldb,
            params->K,
            params->gemm_k_iterations_aligned,
            sgp_sm,
            sgp_sn);
        if (use_out_source) {
          gemm_epilogue(
              Dtile, C, params, addmm_params, sgp_sm, sgp_sn);
        }
        if constexpr (kAlignedM && kAlignedN) {
          Dtile.store(D, int(params->ldd));
        } else {
          Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm));
        }
      });
    });
  });
}

================================================
FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal
================================================
// Copyright © 2025 Apple Inc.
// NOTE(review): the first #include below has no header name in this copy
// (its <...> operand appears to have been lost). Restore it from upstream.
#include
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h"

// clang-format off
// Instantiates the NAX `gemm` kernel through instantiate_kernel under the
// name "steel_gemm_fused_nax_{tname}_{iname}_{oname}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}",
// with float accumulation. otype is accepted for symmetry but not used.
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_kernel( \
      "steel_gemm_fused_nax_" #tname "_" #iname "_" #oname \
      "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
  gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)

// Instantiates the four transpose variants. In the tag, the first letter is
// for A and the second for B (n = not transposed, t = transposed).
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
    instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
    instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
    instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
    instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

// Instantiates every (bm, bn, bk, wm, wn) tile configuration for one type.
// Every bk here is a multiple of the kernel's 32-wide SK step.
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 128, 64, 2, 4) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 128, 256, 2, 4) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 64, 4, 4) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 256, 4, 4) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4)

// Element types compiled into the library.
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
instantiate_gemm_shapes_helper(float32, float, float32, float);
// clang-format on

================================================
FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h
================================================
// Copyright © 2024 Apple Inc.
using namespace mlx::steel; constant bool has_batch [[function_constant(10)]]; constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; constant bool align_K [[function_constant(202)]]; template < typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, typename AccumType = float> [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gather_mm_rhs( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], const device uint32_t* rhs_indices [[buffer(2)]], device T* C [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]]) { using gemm_kernel = GEMMKernel< T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, true, true, AccumType>; using loader_a_t = typename gemm_kernel::loader_a_t; using loader_b_t = typename gemm_kernel::loader_b_t; using mma_t = typename gemm_kernel::mma_t; if (params->tiles_n <= static_cast(tid.x) || params->tiles_m <= static_cast(tid.y)) { return; } // Prepare threadgroup memory threadgroup T As[gemm_kernel::tgp_mem_size_a]; threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; // Find the block in A, B, C const int c_row = tid.y * BM; const int c_col = tid.x * BN; const size_t c_row_long = size_t(c_row); const size_t c_col_long = size_t(c_col); // Prepare threadgroup bounds const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); A += transpose_a ? c_row_long : c_row_long * params->lda; B += transpose_b ? 
c_col_long * params->ldb : c_col_long; C += c_row_long * params->ldd + c_col_long; // Do as many matmuls as necessary uint32_t index; short offset; uint32_t index_next = rhs_indices[c_row]; short offset_next = 0; int n = 0; while (n < tgp_bm) { n++; offset = offset_next; index = index_next; offset_next = tgp_bm; for (; n < tgp_bm; n++) { if (rhs_indices[c_row + n] != index) { offset_next = n; index_next = rhs_indices[c_row + n]; break; } } threadgroup_barrier(mem_flags::mem_none); // Prepare threadgroup mma operation thread mma_t mma_op(simd_group_id, simd_lane_id); // Prepare threadgroup loading operations thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); thread loader_b_t loader_b( B + index * params->batch_stride_b, params->ldb, Bs, simd_group_id, simd_lane_id); // Prepare iterations const int gemm_k_iterations = params->gemm_k_iterations_aligned; // Do unaligned K iterations first if (!align_K) { const int k_last = params->gemm_k_iterations_aligned * BK; const int k_remain = params->K - k_last; const size_t k_jump_a = transpose_a ? params->lda * size_t(k_last) : size_t(k_last); const size_t k_jump_b = transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); // Move loader source ahead to end loader_a.src += k_jump_a; loader_b.src += k_jump_b; // Load tile const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); loader_a.load_safe(tile_dims_A); loader_b.load_safe(tile_dims_B); threadgroup_barrier(mem_flags::mem_threadgroup); // Do matmul mma_op.mma(As, Bs); // Reset source back to start loader_a.src -= k_jump_a; loader_b.src -= k_jump_b; } // Matrix level aligned never check if (align_M && align_N) { for (int k = 0; k < gemm_k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } // Store results to device memory if (offset_next - offset == BM) { mma_op.store_result(C, params->ldd); } else { mma_op.store_result_slice( C, params->ldd, short2(0, offset), short2(BN, offset_next)); } } else { const short lbk = 0; // Tile aligned don't check if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, lbk, LoopAlignment{}); if (offset_next - offset == BM) { mma_op.store_result(C, params->ldd); } else { mma_op.store_result_slice( C, params->ldd, short2(0, offset), short2(BN, offset_next)); } } // Tile partially aligned check rows else if (align_N || tgp_bn == BN) { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, lbk, LoopAlignment{}); mma_op.store_result_slice( C, params->ldd, short2(0, offset), short2(BN, offset_next)); } // Tile partially aligned check cols else if (align_M || tgp_bm == BM) { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, lbk, LoopAlignment{}); mma_op.store_result_slice( C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); } // Nothing aligned so check both rows and cols else { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, 
loader_a, loader_b, mma_op, tgp_bm, tgp_bn, lbk, LoopAlignment{}); mma_op.store_result_slice( C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); } } } } template < typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, typename AccumType = float> [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gather_mm( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], const device uint32_t* lhs_indices [[buffer(2)]], const device uint32_t* rhs_indices [[buffer(3)]], device T* C [[buffer(4)]], const constant GEMMParams* params [[buffer(5)]], const constant int* indices_shape [[buffer(6)]], const constant int64_t* lhs_strides [[buffer(7)]], const constant int64_t* rhs_strides [[buffer(8)]], const constant int& batch_ndim_a [[buffer(9)]], const constant int* batch_shape_a [[buffer(10)]], const constant int64_t* batch_strides_a [[buffer(11)]], const constant int& batch_ndim_b [[buffer(12)]], const constant int* batch_shape_b [[buffer(13)]], const constant int64_t* batch_strides_b [[buffer(14)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]]) { using gemm_kernel = GEMMKernel< T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, true, true, AccumType>; using loader_a_t = typename gemm_kernel::loader_a_t; using loader_b_t = typename gemm_kernel::loader_b_t; using mma_t = typename gemm_kernel::mma_t; if (params->tiles_n <= static_cast(tid.x) || params->tiles_m <= static_cast(tid.y)) { return; } // Move A and B to the locations pointed by lhs_indices and rhs_indices. 
uint32_t indx_A, indx_B; if (has_batch) { ulong2 indices_offsets = elem_to_loc_broadcast( tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim); indx_A = lhs_indices[indices_offsets.x]; indx_B = rhs_indices[indices_offsets.y]; } else { indx_A = lhs_indices[params->batch_stride_a * tid.z]; indx_B = rhs_indices[params->batch_stride_b * tid.z]; } A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a); B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b); C += params->batch_stride_d * tid.z; // Prepare threadgroup memory threadgroup T As[gemm_kernel::tgp_mem_size_a]; threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; // Just make sure everybody's finished with the indexing math above. threadgroup_barrier(mem_flags::mem_none); // Find block in A, B, C const int c_row = tid.y * BM; const int c_col = tid.x * BN; const size_t c_row_long = size_t(c_row); const size_t c_col_long = size_t(c_col); A += transpose_a ? c_row_long : c_row_long * params->lda; B += transpose_b ? c_col_long * params->ldb : c_col_long; C += c_row_long * params->ldd + c_col_long; // Prepare threadgroup mma operation thread mma_t mma_op(simd_group_id, simd_lane_id); // Prepare threadgroup loading operations thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); // Prepare threadgroup bounds const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); // Prepare iterations int gemm_k_iterations = params->gemm_k_iterations_aligned; // Do unaligned K iterations first if (!align_K) { const int k_last = params->gemm_k_iterations_aligned * BK; const int k_remain = params->K - k_last; const size_t k_jump_a = transpose_a ? params->lda * size_t(k_last) : size_t(k_last); const size_t k_jump_b = transpose_b ? 
size_t(k_last) : params->ldb * size_t(k_last); // Move loader source ahead to end loader_a.src += k_jump_a; loader_b.src += k_jump_b; // Load tile const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); loader_a.load_safe(tile_dims_A); loader_b.load_safe(tile_dims_B); threadgroup_barrier(mem_flags::mem_threadgroup); // Do matmul mma_op.mma(As, Bs); // Reset source back to start loader_a.src -= k_jump_a; loader_b.src -= k_jump_b; } // Matrix level aligned never check if (align_M && align_N) { for (int k = 0; k < gemm_k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } // Store results to device memory mma_op.store_result(C, params->ldd); } else { const short lbk = 0; // Tile aligned don't check if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, lbk, LoopAlignment{}); mma_op.store_result(C, params->ldd); } // Tile partially aligned check rows else if (align_N || tgp_bn == BN) { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, lbk, LoopAlignment{}); mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); } // Tile partially aligned check cols else if (align_M || tgp_bm == BM) { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, lbk, LoopAlignment{}); mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); } // Nothing aligned so check both rows and cols else { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, 
tgp_bn, lbk, LoopAlignment{}); mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); } } } ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal ================================================ // Copyright © 2024 Apple Inc. // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h" #define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_kernel( \ "steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ "_bk" #bk "_wm" #wm "_wn" #wn, \ gather_mm_rhs, \ itype, \ bm, \ bn, \ bk, \ wm, \ wn, \ trans_a, \ trans_b, \ float) #define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_kernel( \ "steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ "_bk" #bk "_wm" #wm "_wn" #wn, \ gather_mm, \ itype, \ bm, \ bn, \ bk, \ wm, \ wn, \ trans_a, \ trans_b, \ float) #define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn) #define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) #define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \ 
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \ instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on instantiate_gather_mm_shapes_helper(float16, half, float16, half); instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); instantiate_gather_mm_shapes_helper(float32, float, float32, float); ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h ================================================ // Copyright © 2024 Apple Inc. using namespace mlx::steel; constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; constant bool align_K [[function_constant(202)]]; template < typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, typename AccumType = float> [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gather_mm_rhs_nax( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], const device uint32_t* rhs_indices [[buffer(2)]], device T* C [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]]) { constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; constexpr short TM = SM / 16; constexpr short TN = SN / 16; if (params->tiles_n <= static_cast(tid.x) || params->tiles_m <= static_cast(tid.y)) { return; } // Find the block in A, B, C const int c_row = tid.y * BM; const int c_col = tid.x * BN; const size_t c_row_long 
= size_t(c_row); const size_t c_col_long = size_t(c_col); A += transpose_a ? c_row_long : c_row_long * params->lda; B += transpose_b ? c_col_long * params->ldb : c_col_long; C += c_row_long * params->ldd + c_col_long; rhs_indices += c_row; const short tm = SM * (simd_group_id / WN); const short tn = SN * (simd_group_id % WN); const int sgp_sm_int = align_M ? int(SM) : min(int(SM), params->M - (c_row + tm)); const short sgp_sm = short(sgp_sm_int); const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); const int sgp_sn_int = align_N ? int(SN) : min(int(SN), params->N - (c_col + tn)); const short sgp_sn = short(sgp_sn_int); const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); A += transpose_a ? tm : (tm * params->lda); B += transpose_b ? (tn * params->ldb) : tn; C += tm * params->ldd + tn; rhs_indices += tm; // Do as many matmuls as necessary uint32_t index; short offset; uint32_t index_next = rhs_indices[0]; short offset_next = 0; int n = 0; while (n < sgp_sm) { n++; offset = offset_next; index = index_next; offset_next = sgp_sm; for (; n < sgp_sm; n++) { if (rhs_indices[n] != index) { offset_next = n; index_next = rhs_indices[n]; break; } } threadgroup_barrier(mem_flags::mem_none); NAXTile Ctile; dispatch_bool(align_K, [&](auto kAlignedK) { dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { auto do_gemm = gemm_loop< // Matmul for partial BM, full BN and full K T, SM, SN, SK, BK, transpose_a, transpose_b, kAlignedM.value, kAlignedN.value, kAlignedK.value, AccumType>; Ctile = do_gemm( A, B + index * params->batch_stride_b, params->lda, params->ldb, params->K, params->gemm_k_iterations_aligned, sgp_sm, sgp_sn); if constexpr (kAlignedN.value) { if (offset_next - offset == SM) { Ctile.store(C, int(params->ldd)); } else { Ctile.store_slice( C, int(params->ldd), short2(0, offset), short2(SN, offset_next)); } } else { Ctile.store_slice( C, int(params->ldd), short2(0, offset), 
short2(sgp_sn, offset_next)); } }); }); }); } } ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal ================================================ // Copyright © 2024 Apple Inc. #include #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h" // clang-format off #define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_kernel( \ "steel_gather_mm_rhs_nax_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ "_bk" #bk "_wm" #wm "_wn" #wn, \ gather_mm_rhs_nax, \ itype, \ bm, \ bn, \ bk, \ wm, \ wn, \ trans_a, \ trans_b, \ float) #define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn) #define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \ instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 128, 128, 1, 4) \ instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 32, 128, 128, 1, 4) \ instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 64, 128, 128, 2, 4) // clang-format on instantiate_gather_mm_shapes_helper(float16, half, float16, half); instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat); ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h ================================================ // Copyright © 2024 Apple Inc. 
#include "mlx/backend/metal/kernels/steel/defines.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

// Placeholder "mask" type: it converts to `true` in every address space
// (default, threadgroup, device, constant), so it never disables anything.
// The `char x` member presumably just gives the struct a non-zero size —
// confirm.
struct _NoMask {
  char x;

  constexpr METAL_FUNC operator bool() {
    return true;
  }
  constexpr METAL_FUNC operator bool() const threadgroup {
    return true;
  }
  constexpr METAL_FUNC operator bool() const device {
    return true;
  }
  constexpr METAL_FUNC operator bool() const constant {
    return true;
  }
};

// Element-wise scaling functor: apply(x) converts x and multiplies it by
// `scale`. Below it scales loaded operand tiles (apply_inplace_op) by
// multiplicative operand masks and the accumulated result (apply_epilogue) by
// a multiplicative output mask.
// NOTE(review): the template parameter list (declaring OutT / InT) and the
// static_cast target type appear to be missing in this copy.
template
struct ScaleOp {
  OutT scale;

  METAL_FUNC OutT apply(InT x) const {
    return static_cast(x) * scale;
  }
};

typedef struct _NoMask nomask_t;

// block_masked_gemm: GEMM D = A * B with optional block masks.
//
// Masks are per block, not per element:
//  - out_mask has one entry per BM x BN output tile. A zero/false entry makes
//    the threadgroup write zeros to its tile and return without any GEMM work.
//  - lhs_mask / rhs_mask have one entry per BM-wide block of K for each row
//    tile of A / column tile of B. A K-block is skipped unless both entries
//    are nonzero/true. The mask offsets advance every k_mask_factor = BM / BK
//    K-iterations, i.e. every BM elements of K.
//  - When the masks are "multiplicative" (has_mul_*), non-skipped blocks are
//    also scaled by the mask value through ScaleOp.
// has_operand_mask / has_output_mask / has_mul_* are compile-time flags.
// Presumably they compare op_mask_t / out_mask_t against nomask_t and bool
// via is_same_v, but their template arguments are missing in this copy —
// confirm upstream.
//
// Layouts read here:
//  - batch_strides: [A strides | B strides | out-mask strides (if any) |
//    lhs-mask strides | rhs-mask strides], each batch_ndim entries long.
//  - mask_strides: one 2-entry pair per present mask, in the order out, lhs,
//    rhs. The indexing uses [1] for the row-block index and [0] for the
//    column-block index, so each pair looks ordered (column stride, row
//    stride) — confirm against the host code.
template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void
block_masked_gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant int64_t* batch_strides [[buffer(7)]],
    const device out_mask_t* out_mask [[buffer(10)]],
    const device op_mask_t* lhs_mask [[buffer(11)]],
    const device op_mask_t* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Appease the compiler
  (void)lid;

  // Operand mask blocks are BM x BM, so the mask tiling requires BM == BN and
  // an integer number of BK steps per mask block.
  static_assert(
      BM == BN,
      "block_masked_gemm must have the same block M and block N size");
  static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");

  constexpr bool has_operand_mask = !metal::is_same_v;
  constexpr bool has_output_mask = !metal::is_same_v;

  constexpr bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v;
  constexpr bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v;

  // Number of BK-wide K iterations covered by one operand-mask block.
  constexpr short k_mask_factor = short(BM / BK);

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;

  // Remap the threadgroup grid by params->swizzle_log (presumably to improve
  // cache locality between neighbouring threadgroups).
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Mask batch strides follow the A and B batch strides.
  const constant auto* mask_batch_strides =
      batch_strides + 2 * params->batch_ndim;

  if (params->batch_ndim > 1) {
    if (has_output_mask) {
      out_mask += elem_to_loc(
          tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      const constant auto* mask_strides_lhs = mask_batch_strides;
      const constant auto* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          mask_strides_lhs,
          mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      lhs_mask += tid.z * mask_batch_strides[0];
      rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
    }
  }

  // Adjust for batch
  if (params->batch_ndim > 1) {
    const constant auto* A_bstrides = batch_strides;
    const constant auto* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  // Per-mask 2-entry stride pairs, packed only for masks that are present.
  const constant int* out_mask_strides = mask_strides;
  const constant int* lhs_mask_strides =
      mask_strides + (has_output_mask ? 2 : 0);
  const constant int* rhs_mask_strides =
      lhs_mask_strides + (has_operand_mask ? 2 : 0);

  // Starting offsets for this tile, and per-K-block steps for operand masks.
  const int out_mask_offset = !has_output_mask
      ? 0
      : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
  int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
  int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
  const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
  const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
  // Counts down the K iterations left in the current operand-mask block.
  short k_factor_cnt = k_mask_factor;

  ScaleOp out_mask_op;
  ScaleOp lhs_mask_op;
  ScaleOp rhs_mask_op;

  if (has_output_mask) {
    auto mask_out = out_mask[out_mask_offset];

    if (has_mul_output_mask) {
      out_mask_op.scale = float(mask_out);
    }

    // Write zeros and return
    if (!mask_out) {
      constexpr short tgp_size = WM * WN * 32;
      constexpr short vec_size = 4;

      // Tile threads in threadgroup: each thread zeroes vec_size contiguous
      // columns, and rows are strided by TM.
      constexpr short TN = BN / vec_size;
      constexpr short TM = tgp_size / TN;

      const short thread_idx = simd_group_id * 32 + simd_lane_id;
      const short bi = thread_idx / TN;
      const short bj = vec_size * (thread_idx % TN);

      D += bi * params->ldd + bj;

      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);

      if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
        for (short ti = 0; ti < BM; ti += TM) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      } else {
        // Partial tile: clamp both rows and columns to what lies inside D.
        short jmax = tgp_bn - bj;
        jmax = jmax < vec_size ? jmax : vec_size;
        for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
          for (short j = 0; j < jmax; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      }

      return;
    }
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  const short tgp_bm =
      MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
  const short tgp_bn =
      MN_aligned ? short(BN) : short(min(BN, params->N - c_col));

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  ///////////////////////////////////////////////////////////////////////////////
  // Do unaligned K iterations first
  // The trailing partial K block belongs to operand-mask block k_last / BM.
  if (!K_aligned) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int mask_idx_last = k_last / BM;

    if (!has_operand_mask ||
        (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
         bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
      if (has_mul_operand_mask) {
        lhs_mask_op.scale =
            lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
        rhs_mask_op.scale =
            rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
      }

      // Move loader source ahead to end
      const int k_remain = params->K - k_last;
      const size_t k_jump_a =
          transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
      const size_t k_jump_b =
          transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

      loader_a.src += k_jump_a;
      loader_b.src += k_jump_b;

      // Load tile
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);

      if (has_mul_operand_mask) {
        loader_a.apply_inplace_op(lhs_mask_op);
        loader_b.apply_inplace_op(rhs_mask_op);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Do matmul
      mma_op.mma(As, Bs);

      // Reset source back to start
      loader_a.src -= k_jump_a;
      loader_b.src -= k_jump_b;
    }
  }

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Skip the multiply for masked-out K blocks; loaders still advance.
      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

      // Move to the next operand-mask block every k_mask_factor iterations.
      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;

  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else {
    const bool M_aligned = (tgp_bm == BM);
    const bool N_aligned = (tgp_bn == BN);

    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup (bounds-checked only on a partial
        // edge of this tile)
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

      // Move to the next operand-mask block every k_mask_factor iterations.
      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void
block_masked_gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant int64_t* batch_strides [[buffer(7)]],
    const device bool* out_mask [[buffer(10)]],
    const device bool* lhs_mask [[buffer(11)]],
    const device bool* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Appease the compiler
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;

  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  if (params->batch_ndim > 1) {
    const constant auto* mask_batch_strides =
        batch_strides + 2 * params->batch_ndim;
    out_mask +=
        elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

    if (has_operand_mask) {
      const constant auto* mask_strides_lhs =
          mask_batch_strides + params->batch_ndim;
      const constant auto* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          mask_strides_lhs,
          mask_strides_rhs,
          params->batch_ndim);

      lhs_mask +=
batch_offsets.x; rhs_mask += batch_offsets.y; } } else { out_mask += tid.z * batch_strides[2 * params->batch_ndim]; if (has_operand_mask) { lhs_mask += tid.z * batch_strides[3 * params->batch_ndim]; rhs_mask += tid.z * batch_strides[4 * params->batch_ndim]; } } // Adjust for batch if (params->batch_ndim > 1) { const constant auto* A_bstrides = batch_strides; const constant auto* B_bstrides = batch_strides + params->batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); A += batch_offsets.x; B += batch_offsets.y; } else { A += params->batch_stride_a * tid.z; B += params->batch_stride_b * tid.z; } D += params->batch_stride_d * tid.z; // Find block in A, B, C const int c_row = tid_y * BM; const int c_col = tid_x * BN; const size_t c_row_long = size_t(c_row); const size_t c_col_long = size_t(c_col); A += transpose_a ? c_row_long : c_row_long * params->lda; B += transpose_b ? c_col_long * params->ldb : c_col_long; D += c_row_long * params->ldd + c_col_long; bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; // Write zeros and return if (!mask_out) { constexpr short tgp_size = WM * WN * 32; constexpr short vec_size = 4; // Tile threads in threadgroup constexpr short TN = BN / vec_size; constexpr short TM = tgp_size / TN; const short thread_idx = simd_group_id * 32 + simd_lane_id; const short bi = thread_idx / TN; const short bj = vec_size * (thread_idx % TN); D += bi * params->ldd + bj; short tgp_bm = min(BM, params->M - c_row); short tgp_bn = min(BN, params->N - c_col); if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { for (short ti = 0; ti < BM; ti += TM) { STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { D[ti * params->ldd + j] = T(0.); } } } else { short jmax = tgp_bn - bj; jmax = jmax < vec_size ? 
jmax : vec_size; for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { for (short j = 0; j < jmax; j++) { D[ti * params->ldd + j] = T(0.); } } } return; } threadgroup_barrier(mem_flags::mem_none); // Prepare threadgroup mma operation thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); int gemm_k_iterations = params->gemm_k_iterations_aligned; threadgroup T As[gemm_kernel::tgp_mem_size_a]; threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; // Prepare threadgroup loading operations thread typename gemm_kernel::loader_a_t loader_a( A, params->lda, As, simd_group_id, simd_lane_id); thread typename gemm_kernel::loader_b_t loader_b( B, params->ldb, Bs, simd_group_id, simd_lane_id); /////////////////////////////////////////////////////////////////////////////// // MNK aligned loop if (MN_aligned) { for (int k = 0; k < gemm_k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); if (!has_operand_mask || (lhs_mask [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && rhs_mask [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); } // Prepare for next iteration loader_a.next(); loader_b.next(); } threadgroup_barrier(mem_flags::mem_none); // Loop tail if (!K_aligned) { if (!has_operand_mask || (lhs_mask [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && rhs_mask [(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { int lbk = params->K - params->gemm_k_iterations_aligned * BK; short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); loader_a.load_safe(tile_dims_A); loader_b.load_safe(tile_dims_B); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(As, Bs); } } // Store results to device memory mma_op.store_result(D, params->ldd); return; } /////////////////////////////////////////////////////////////////////////////// // MN unaligned loop else { // Loop over K - unaligned case short tgp_bm = min(BM, params->M - c_row); short tgp_bn = min(BN, params->N - c_col); short lbk = params->K - params->gemm_k_iterations_aligned * BK; bool M_aligned = (tgp_bm == BM); bool N_aligned = (tgp_bn == BN); short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); for (int k = 0; k < gemm_k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); if (!has_operand_mask || (lhs_mask [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && rhs_mask [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { // Load elements into threadgroup if (M_aligned) { loader_a.load_unsafe(); } else { loader_a.load_safe(tile_dims_A); } if (N_aligned) { loader_b.load_unsafe(); } else { loader_b.load_safe(tile_dims_B); } threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); } // Prepare for next iteration loader_a.next(); loader_b.next(); } if (!K_aligned) { threadgroup_barrier(mem_flags::mem_threadgroup); if (!has_operand_mask || (lhs_mask [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && rhs_mask [(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { short2 tile_dims_A_last = transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); short2 tile_dims_B_last = transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); loader_a.load_safe(tile_dims_A_last); loader_b.load_safe(tile_dims_B_last); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(As, Bs); } } if (M_aligned && N_aligned) { mma_op.store_result(D, params->ldd); } else { mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); } } } ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal ================================================ // Copyright © 2024 Apple Inc. // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h" #define instantiate_gemm( \ outmaskname, \ outmasktype, \ opmaskname, \ opmasktype, \ tname, \ trans_a, \ trans_b, \ iname, \ itype, \ oname, \ otype, \ bm, \ bn, \ bk, \ wm, \ wn, \ aname, \ mn_aligned, \ kname, \ k_aligned) \ instantiate_kernel( \ "steel_gemm_block_outmask_" #outmaskname \ "_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \ "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ "_MN_" #aname "_K_" #kname, \ block_masked_gemm, \ itype, \ outmasktype, \ opmasktype, \ bm, \ bn, \ bk, \ wm, \ wn, \ trans_a, \ trans_b, \ mn_aligned, \ k_aligned) #define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(bool_, bool, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(iname, itype, bool_, bool, tname, trans_a, trans_b, 
iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ instantiate_gemm_transpose_helper(iname, itype, 
oname, otype, 32, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h ================================================ // Copyright © 2025 Apple Inc. using namespace mlx::steel; constant bool segments_contiguous [[function_constant(199)]]; constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; template < typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, typename AccumType = float> [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void segmented_mm( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], const device uint32_t* segments [[buffer(2)]], device T* C [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]]) { using gemm_kernel = GEMMKernel< T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, true, true, AccumType>; using loader_a_t = typename gemm_kernel::loader_a_t; using loader_b_t = typename gemm_kernel::loader_b_t; using mma_t = typename gemm_kernel::mma_t; if (params->tiles_n <= static_cast(tid.x) || params->tiles_m <= static_cast(tid.y)) { return; } // Prepare threadgroup memory threadgroup T As[gemm_kernel::tgp_mem_size_a]; threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; // Find the block in A, B, C const int c_row = tid.y * BM; const int c_col = tid.x * BN; const size_t c_row_long = size_t(c_row); const size_t c_col_long = size_t(c_col); // Prepare threadgroup bounds const short tgp_bm = align_M ? 
BM : short(min(BM, params->M - c_row)); const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); // Move the pointers to the output tile A += transpose_a ? c_row_long : c_row_long * params->lda; B += transpose_b ? c_col_long * params->ldb : c_col_long; C += c_row_long * params->ldd + c_col_long; // Move the pointers to the start of the segment uint32_t k_start, k_end; if (segments_contiguous) { k_start = segments[2 * tid.z]; k_end = segments[2 * tid.z + 1]; } else { // We accept either contiguous (above) or weird strides where the beginning // of the next one is the previous one. Basically the last two strides are // both 1! k_start = segments[tid.z]; k_end = segments[tid.z + 1]; } A += transpose_a ? k_start * params->lda : k_start; B += transpose_b ? k_start : k_start * params->ldb; C += tid.z * params->batch_stride_d; // Prepare threadgroup mma operation thread mma_t mma_op(simd_group_id, simd_lane_id); // Prepare threadgroup loading operations thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); // Matrix level alignment so only check K if (align_M && align_N) { uint32_t k = k_start + BK; for (; k <= k_end; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } short k_remain = BK - short(k - k_end); const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); if (k_remain > 0) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_a.load_safe(tile_dims_A); loader_b.load_safe(tile_dims_B); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(As, Bs); } mma_op.store_result(C, params->ldd); } else { // Tile aligned do the same as above if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { uint32_t k = k_start + BK; for (; k <= k_end; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } short k_remain = BK - short(k - k_end); const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); if (k_remain > 0) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_a.load_safe(tile_dims_A); loader_b.load_safe(tile_dims_B); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(As, Bs); } mma_op.store_result(C, params->ldd); } // Tile partially aligned check rows else if (align_N || tgp_bn == BN) { uint32_t k = k_start + BK; for (; k <= k_end; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_safe( transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } short k_remain = BK - short(k - k_end); const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); if (k_remain > 0) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_a.load_safe(tile_dims_A); loader_b.load_safe(tile_dims_B); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(As, Bs); } mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); } // Tile partially aligned check cols else if (align_M || tgp_bm == BM) { uint32_t k = k_start + BK; for (; k <= k_end; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_safe( transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } short k_remain = BK - short(k - k_end); const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); if (k_remain > 0) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_a.load_safe(tile_dims_A); loader_b.load_safe(tile_dims_B); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(As, Bs); } mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); } // Nothing aligned so check both rows and cols else { uint32_t k = k_start + BK; for (; k <= k_end; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup loader_a.load_safe( transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); loader_b.load_safe( transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } short k_remain = BK - short(k - k_end); const short2 tile_dims_A = transpose_a ? 
short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); if (k_remain > 0) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_a.load_safe(tile_dims_A); loader_b.load_safe(tile_dims_B); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(As, Bs); } mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); } } } ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal ================================================ // Copyright © 2024 Apple Inc. // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h" #define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_kernel( \ "steel_segmented_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ "_bk" #bk "_wm" #wm "_wn" #wn, \ segmented_mm, \ itype, \ bm, \ bn, \ bk, \ wm, \ wn, \ trans_a, \ trans_b, \ float) #define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) #define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \ instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ 
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on instantiate_segmented_mm_shapes_helper(float16, half, float16, half); instantiate_segmented_mm_shapes_helper( bfloat16, bfloat16_t, bfloat16, bfloat16_t); instantiate_segmented_mm_shapes_helper(float32, float, float32, float); ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h ================================================ // Copyright © 2024 Apple Inc. using namespace mlx::steel; /////////////////////////////////////////////////////////////////////////////// // GEMM kernels /////////////////////////////////////////////////////////////////////////////// template < typename T, typename U, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned> [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], device U* C [[buffer(2)]], const constant GEMMSpiltKParams* params [[buffer(3)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { (void)lid; using gemm_kernel = GEMMKernel< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>; using loader_a_t = typename gemm_kernel::loader_a_t; using loader_b_t = typename gemm_kernel::loader_b_t; using mma_t = typename gemm_kernel::mma_t; threadgroup T As[gemm_kernel::tgp_mem_size_a]; threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; const int tid_x = tid.x; const int tid_y = tid.y; const int tid_z = tid.z; if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { return; } // Find block in A, B, C const int c_row = tid_y * BM; const int c_col = tid_x * BN; const int 
k_start = params->split_k_partition_size * tid_z; const size_t c_row_long = size_t(c_row); const size_t c_col_long = size_t(c_col); const size_t k_start_long = size_t(k_start); A += transpose_a ? (c_row_long + k_start_long * params->lda) : (k_start_long + c_row_long * params->lda); B += transpose_b ? (k_start_long + c_col_long * params->ldb) : (c_col_long + k_start_long * params->ldb); C += (size_t(params->split_k_partition_stride) * tid_z) + (c_row_long * params->ldc + c_col_long); // Prepare threadgroup loading operations thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); // Prepare threadgroup mma operation thread mma_t mma_op(simd_group_id, simd_lane_id); int gemm_k_iterations = params->gemm_k_iterations_aligned; short tgp_bm = min(BM, params->M - c_row); short tgp_bn = min(BN, params->N - c_col); short leftover_bk = params->K % BK; if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, leftover_bk, LoopAlignment{}); } else if (tgp_bn == BN) { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, leftover_bk, LoopAlignment{}); } else if (tgp_bm == BM) { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, leftover_bk, LoopAlignment{}); } else { gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, leftover_bk, LoopAlignment{}); } threadgroup_barrier(mem_flags::mem_threadgroup); if ((tid_z + 1) == (params->split_k_partitions)) { int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK; if (!K_aligned || gemm_k_iter_remaining > 0) gemm_kernel::gemm_loop( As, Bs, gemm_k_iter_remaining, loader_a, loader_b, mma_op, tgp_bm, tgp_bn, leftover_bk, LoopAlignment{}); } if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { 
mma_op.store_result(C, params->ldc); } else { mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); } } /////////////////////////////////////////////////////////////////////////////// // Split k accumulation kernel /////////////////////////////////////////////////////////////////////////////// template < typename AccT, typename OutT, typename Epilogue = TransformNone> [[kernel]] void gemm_splitk_accum( const device AccT* C_split [[buffer(0)]], device OutT* D [[buffer(1)]], const constant int& k_partitions [[buffer(2)]], const constant int& partition_stride [[buffer(3)]], const constant int& ldd [[buffer(4)]], uint2 gid [[thread_position_in_grid]]) { // Ajust D and C D += gid.x + gid.y * size_t(ldd); C_split += gid.x + gid.y * size_t(ldd); size_t offset = 0; AccT out = 0; for (int i = 0; i < k_partitions; i++) { out += C_split[offset]; offset += partition_stride; } // Write output D[0] = Epilogue::apply(out); } template < typename AccT, typename OutT, typename Epilogue = TransformAxpby> [[kernel]] void gemm_splitk_accum_axpby( const device AccT* C_split [[buffer(0)]], device OutT* D [[buffer(1)]], const constant int& k_partitions [[buffer(2)]], const constant int& partition_stride [[buffer(3)]], const constant int& ldd [[buffer(4)]], const device OutT* C [[buffer(5)]], const constant int& ldc [[buffer(6)]], const constant int& fdc [[buffer(7)]], const constant float& alpha [[buffer(8)]], const constant float& beta [[buffer(9)]], uint2 gid [[thread_position_in_grid]]) { // Ajust D and C C += gid.x * size_t(fdc) + gid.y * size_t(ldc); D += gid.x + gid.y * size_t(ldd); C_split += gid.x + gid.y * size_t(ldd); size_t offset = 0; AccT out = 0; for (int i = 0; i < k_partitions; i++) { out += C_split[offset]; offset += partition_stride; } // Write output Epilogue op(alpha, beta); D[0] = op.apply(out, *C); } ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal 
================================================ // Copyright © 2024 Apple Inc. // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h" #define instantiate_gemm( \ tname, \ trans_a, \ trans_b, \ iname, \ itype, \ oname, \ otype, \ bm, \ bn, \ bk, \ wm, \ wn, \ aname, \ mn_aligned, \ kname, \ k_aligned) \ instantiate_kernel( \ "steel_gemm_splitk_" #tname "_" #iname "_" #oname \ "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ "_MN_" #aname "_K_" #kname, \ gemm_splitk, \ itype, \ otype, \ bm, \ bn, \ bk, \ wm, \ wn, \ trans_a, \ trans_b, \ mn_aligned, \ k_aligned) #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \ 
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) instantiate_gemm_shapes_helper(float16, half, float32, float); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float); instantiate_gemm_shapes_helper(float32, float, float32, float); instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t); #define instantiate_accum(oname, otype, aname, atype) \ instantiate_kernel( \ "steel_gemm_splitk_accum_" #oname "_" #aname, \ gemm_splitk_accum, atype, otype) \ instantiate_kernel( \ "steel_gemm_splitk_accum_" #oname "_" #aname "_axbpy", \ gemm_splitk_accum_axpby, atype, otype) \ instantiate_accum(bfloat16, bfloat16_t, float32, float); instantiate_accum(float16, half, float32, float); instantiate_accum(float32, float, float32, float); instantiate_accum(complex64, complex64_t, complex64, complex64_t); // clang-format on ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h ================================================ // Copyright © 2026 Apple Inc. 
using namespace mlx::steel;

// Function constants specialized when the pipeline is built.
// When align_M / align_N is true, every simdgroup sub-tile is treated as fully
// in bounds along M / N (sgp_sm / sgp_sn are set to SM / SN below) and the
// runtime edge computation is skipped.
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];

///////////////////////////////////////////////////////////////////////////////
// NAX Split-K GEMM kernel
///////////////////////////////////////////////////////////////////////////////

// Split-K GEMM on the NAX tile path.
//
// Each threadgroup computes one BM x BN output tile for a single K partition
// [k_start, k_end). It writes the un-reduced partial result (AccumType) into
// that partition's slice of C, at offset split_k_partition_stride * partition.
// NOTE(review): the per-partition slices are presumably summed by a separate
// accumulation pass launched by the host — confirm on the host side.
//
// The grid is one-dimensional: tid.x is decoded into
// (K partition, tile column, tile row) using an X-Y swizzle of width
// 2^swizzle_log.
//
// Template parameters:
//   T              - element type of A and B.
//   BM, BN, BK     - threadgroup tile extents along M, N and K.
//   WM, WN         - simdgroup grid; each simdgroup owns a (BM/WM) x (BN/WN)
//                    sub-tile. The threadgroup is capped at WM * WN * 32 threads.
//   transpose_a/b  - whether A / B are addressed in transposed layout.
//   AccumType      - element type of the partial results written to C.
// Buffers:
//   A, B   - input operands.
//   C      - partial-result buffer, one slice per K partition.
//   params - split-K parameters (M, N, K, leading dims, tile counts, swizzle,
//            partition size/stride).
// clang-format off
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk_nax(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device AccumType* C [[buffer(2)]],
    const constant GEMMSpiltKParams* params [[buffer(3)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  // clang-format on
  // Only tid.x is used: partition and 2-D tile position are both packed into it.
  const int linear_tid = tid.x;

  // Compute swizzled tile dimensions
  // Swizzled grid width is tiles_n * 2^swizzle_log. Its height is tiles_m
  // divided by 2^swizzle_log, rounded up, so some threadgroups can map past
  // tiles_m; those exit below.
  const int tn_swizzled = params->tiles_n << params->swizzle_log;
  const int tm_swizzled =
      (params->tiles_m + (1 << params->swizzle_log) - 1) >> params->swizzle_log;
  const int tiles_per_partition = tn_swizzled * tm_swizzled;

  // Each consecutive run of tiles_per_partition threadgroups shares one K
  // partition.
  const int tid_z = linear_tid / tiles_per_partition;
  const int xy_flat = linear_tid % tiles_per_partition;

  // Decode 2D grid coordinates in swizzled space
  const int grid_x = xy_flat % tn_swizzled;
  const int grid_y = xy_flat / tn_swizzled;

  // Apply X-Y swizzle
  // The low swizzle_log bits of grid_x pick a row inside a band of
  // 2^swizzle_log tile rows. The remaining bits pick the tile column.
  const int tid_y = (grid_y << params->swizzle_log) +
      (grid_x & ((1 << params->swizzle_log) - 1));
  const int tid_x = grid_x >> params->swizzle_log;

  // Exit early
  // Skip threadgroups that land outside the real tile grid (possible because
  // tm_swizzled is rounded up).
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Calculate partition bounds
  // Output tile origin, plus this partition's K range [k_start, k_end).
  // k_end is clamped to K, so the last partition may be shorter.
  // NOTE(review): if k_start >= K, partition_k_size below is <= 0. How
  // gemm_loop handles that is not visible here; presumably the host never
  // launches such partitions — confirm.
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int k_start = params->split_k_partition_size * tid_z;
  const int k_end = min(k_start + params->split_k_partition_size, params->K);

  // Offsets are widened to size_t before they are multiplied by leading
  // dimensions (presumably to avoid 32-bit overflow on large matrices).
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);
  const size_t k_start_long = size_t(k_start);

  // Adjust pointers for split-K partition
  // A / B move to (tile row / tile column, k_start). C moves to this
  // partition's slice, then to the tile origin.
  A += transpose_a ? (c_row_long + k_start_long * params->lda)
                   : (k_start_long + c_row_long * params->lda);
  B += transpose_b ? (k_start_long + c_col_long * params->ldb)
                   : (c_col_long + k_start_long * params->ldb);
  C += (size_t(params->split_k_partition_stride) * tid_z) +
      (c_row_long * params->ldc + c_col_long);

  // NAX tile configuration
  // SM x SN is the sub-tile owned by one simdgroup. SK is passed to gemm_loop
  // as its SK template argument. TM / TN are SM / 16 and SN / 16 (presumably
  // the number of 16-wide fragments per sub-tile — confirm against NAXTile).
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;
  constexpr short TM = SM / 16;
  constexpr short TN = SN / 16;

  // Calculate simdgroup offsets and alignment
  // Simdgroups are laid out row-major over a WM x WN grid. (tm, tn) is this
  // simdgroup's offset inside the threadgroup tile.
  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  // Number of rows / columns of this sub-tile that lie inside the matrix.
  // Unless align_M / align_N is set, the count is clamped at the M / N edge.
  // It can be smaller than SM / SN, or non-positive for a sub-tile entirely
  // past the edge.
  const int sgp_sm_int =
      align_M ? int(SM) : min(int(SM), params->M - (c_row + tm));
  const short sgp_sm = short(sgp_sm_int);
  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

  const int sgp_sn_int =
      align_N ? int(SN) : min(int(SN), params->N - (c_col + tn));
  const short sgp_sn = short(sgp_sn_int);
  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

  // Move operands and output to this simdgroup's sub-tile.
  A += transpose_a ? tm : (tm * params->lda);
  B += transpose_b ? (tn * params->ldb) : tn;
  C += tm * params->ldc + tn;

  // NOTE(review): NAXTile's template arguments appear to be missing from this
  // extracted copy. TM and TN above are otherwise unused here, so this was
  // presumably NAXTile<AccumType, TM, TN> — confirm against upstream.
  NAXTile Dtile;

  // gemm_loop through the partition
  // Check K-alignment at runtime (partition-specific)
  const int partition_k_size = k_end - k_start;
  const int partition_k_iters = partition_k_size / BK;
  const bool partition_k_aligned = (partition_k_size % BK) == 0;

  // dispatch_bool turns each runtime flag into a value whose `.value` is used
  // as a template argument. This selects the gemm_loop instantiation matching
  // the K / M / N alignment of this partition and sub-tile.
  dispatch_bool(partition_k_aligned, [&](auto kAlignedK) {
    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
      dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
        Dtile = gemm_loop<
            T,
            SM,
            SN,
            SK,
            BK,
            transpose_a,
            transpose_b,
            kAlignedM.value,
            kAlignedN.value,
            kAlignedK.value,
            AccumType>(
            A,
            B,
            params->lda,
            params->ldb,
            partition_k_size,
            partition_k_iters,
            sgp_sm,
            sgp_sn);
      });
    });
  });

  // Store result
  // If the sub-tile is fully in bounds, use store(). Otherwise store_safe()
  // receives the valid extent short2(sgp_sn, sgp_sm).
  dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
    dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
      if constexpr (kAlignedM && kAlignedN) {
        Dtile.store(C, int(params->ldc));
      } else {
        Dtile.store_safe(C, int(params->ldc), short2(sgp_sn, sgp_sm));
      }
    });
  });
}
================================================ FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal ================================================ // Copyright © 2026 Apple Inc.
// NOTE(review): the header name of the first #include below appears to have
// been lost in this extracted copy (presumably <metal_stdlib>) — confirm
// against the upstream source.
#include
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h"

// clang-format off
// Instantiates one gemm_splitk_nax specialization under the kernel name
//   steel_gemm_splitk_nax_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>
// NOTE(review): the host presumably looks kernels up by this exact string, so
// keep it in sync with the host-side name builder.
// The kernel's AccumType template argument is fixed to float. otype is not
// used by the expansion; oname only contributes to the name.
#define instantiate_gemm_splitk(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_kernel( \
    "steel_gemm_splitk_nax_" #tname "_" #iname "_" #oname \
    "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
  gemm_splitk_nax, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)

// Expands the four transpose variants (tname, transpose_a, transpose_b):
// nn, nt, tn, tt.
#define instantiate_gemm_splitk_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_splitk(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_splitk(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_splitk(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_splitk(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

// Tile configurations (bm, bn, bk, wm, wn) built for each input type:
// 64x64x256 with a 2x2 simdgroup grid, and 128x128x512 with a 4x4 grid.
#define instantiate_gemm_splitk_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_splitk_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \
  instantiate_gemm_splitk_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4)

// Supported input types. The partial-result name suffix is float32 for all
// of them, matching the float AccumType fixed in instantiate_gemm_splitk.
instantiate_gemm_splitk_shapes_helper(float16, half, float32, float);
instantiate_gemm_splitk_shapes_helper(bfloat16, bfloat, float32, float);
instantiate_gemm_splitk_shapes_helper(float32, float, float32, float);
// clang-format on
================================================ FILE: mlx/backend/metal/kernels/steel/gemm/loader.h ================================================ // Copyright © 2024 Apple Inc.
#pragma once

#include "mlx/backend/metal/kernels/steel/defines.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

// Cooperative copy of a BROWS x BCOLS tile from device memory into threadgroup
// memory.
//
// Each thread owns `n_reads` consecutive elements of one row, starting at
// (bi, bj), and repeats that every TROWS rows. `TCOLS` threads cover one row.
// The defaults assume `tgp_size` threads split the tile evenly.
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  // Elements `src` advances per next(): BCOLS columns when reduction_dim is
  // nonzero, otherwise BROWS rows.
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi; // row of this thread's first element in the tile
  const short bj; // column of this thread's first element in the tile

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  // Raw byte vector used to move `vec_size` elements with one wide access.
  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

  /* Constructor */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        // Flat thread index; assumes 32 lanes per simdgroup.
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Apply operation to threadgroup without bound checking */
  template <typename UnaryOp>
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
      }
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
          *((const device ReadVector*)(&src[i * src_ld]));
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking.
   * `src_tile_dim` is (cols, rows) of valid data measured from the tile
   * origin. Out-of-range elements are written as T(0). */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    // Make the remaining extent relative to this thread's first element.
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val. Invalid lanes read src[0] so the load
      // stays unconditional; src[0] is in range because of the early return.
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Iteration helper: advance the source pointer to the next tile */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};

} // namespace steel
} // namespace mlx
================================================
FILE: mlx/backend/metal/kernels/steel/gemm/mma.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

// Primary template: only the 8x8 specialization below is implemented.
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
  static_assert(
      kFragRows_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
  static_assert(
      kFragCols_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
};

// 8x8 simdgroup matrix fragment.
//
// Each of the 32 lanes holds kElemsPerFrag = 2 elements: one row
// (kElemRows = 1) and two adjacent columns (kElemCols = 2). The lane's
// position inside the fragment comes from get_coord().
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
  STEEL_CONST int kFragRows = 8;
  STEEL_CONST int kFragCols = 8;

  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;

  STEEL_CONST int kElemRows = 1;
  STEEL_CONST int kElemCols = 2;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "MMAFrag shape is not consistent with MMAFrag size");

  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
  typedef metal::vec<T, kElemsPerFrag> frag_type;

  // Returns (col, row) of this lane's first element within the 8x8 fragment.
  METAL_FUNC static constexpr short2 get_coord(
      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
    const short qid = simd_lane_id / 4;
    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
    return short2{fn, fm};
  }

  // Unchecked load of this lane's elements; `src` must already point at the
  // lane's first element.
  template <typename SrcPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
      }
    }
  }

  // Bounds-checked load. Element (off_x + i, off_y + j) is read only if its
  // row is < lim_x and its column is < lim_y; otherwise it is zero-filled.
  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void load_safe(
      thread frag_type& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          // Column index uses off_y, the same offset the bound check tests
          // (previously off_x was used here by mistake).
          dst[i * kElemCols + j] =
              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  // Unchecked store of this lane's elements, converted to the pointee type.
  template <typename DstPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
      }
    }
  }

  // Bounds-checked store: writes only elements with row < lim_x and
  // column < lim_y.
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void store_safe(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // Store restricted to rows [start_x, stop_x) and columns [start_y, stop_y).
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename StartX,
      typename StopX,
      typename StartY,
      typename StopY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void store_slice(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      StartX start_x,
      StopX stop_x,
      StartY start_y,
      StopY stop_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
            (off_y + j) < stop_y && (off_y + j) >= start_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // D = A * B + C on per-lane fragments, via simdgroup matrices.
  METAL_FUNC static constexpr void mma(
      thread frag_type& D,
      thread frag_type& A,
      thread frag_type& B,
      thread frag_type& C) {
    mat_type D_mat;
    mat_type A_mat;
    mat_type B_mat;
    mat_type C_mat;

    reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
    reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
    reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;

    mma(D_mat, A_mat, B_mat, C_mat);

    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
  }

  METAL_FUNC static constexpr void mma(
      thread mat_type& D,
      thread mat_type& A,
      thread mat_type& B,
      thread mat_type& C) {
    simdgroup_multiply_accumulate(D, A, B, C);
  }
};

// A kTileRows x kTileCols grid of MMA fragments held in registers, stored
// row-major in val_frags. The w_x / w_y template parameters of the load/store
// helpers scale the fragment-to-fragment offset (fragment (i, j) starts at
// row i * kFragRows * w_x, column j * kFragCols * w_y).
template <
    typename T,
    int kTileRows_,
    int kTileCols_,
    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
  using MMAFrag_t = MMAFrag_;
  using elem_type = T;
  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;

  STEEL_CONST int kTileRows = kTileRows_;
  STEEL_CONST int kTileCols = kTileCols_;

  STEEL_CONST int kRows = kTileRows * kFragRows;
  STEEL_CONST int kCols = kTileCols * kFragCols;

  STEEL_CONST int kNumFrags = kTileRows * kTileCols;
  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;

  typedef typename MMAFrag_t::mat_type mat_type;
  typedef typename MMAFrag_t::frag_type frag_type;

  // The remaining elements of the aggregate initializer are value-initialized
  // to zero, so a fresh tile is all zeros.
  frag_type val_frags[kNumFrags] = {frag_type(0)};

  METAL_FUNC MMATile() thread {}

  METAL_FUNC constexpr void clear() {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kNumFrags; ++i) {
      val_frags[i] = frag_type(0);
    }
  }

  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
    return val_frags[i * kTileCols + j];
  }

  METAL_FUNC constexpr const thread frag_type& frag_at(
      const short i,
      const short j) const {
    return val_frags[i * kTileCols + j];
  }

  METAL_FUNC mat_type mat_at(const short i, const short j) {
    mat_type val_mat;
    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < kElemsPerFrag; ++ii) {
      val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
    }
    return val_mat;
  }

  // Flat view of all kElemsPerTile per-lane elements.
  METAL_FUNC thread elem_type* elems() {
    return reinterpret_cast<thread elem_type*>(val_frags);
  }

  METAL_FUNC const thread elem_type* elems() const {
    return reinterpret_cast<const thread elem_type*>(val_frags);
  }

  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void load(const threadgroup U* src) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load(
            frag_at(i, j),
            &(
                src[(i * kFragRows) * w_x * str_x +
                    (j * kFragCols) * w_y * str_y]),
            Int<str_x>{},
            Int<str_y>{});
      }
    }
  }

  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void store(threadgroup U* dst) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store(
            frag_at(i, j),
            &(
                dst[(i * kFragRows) * w_x * str_x +
                    (j * kFragCols) * w_y * str_y]),
            Int<str_x>{},
            Int<str_y>{});
      }
    }
  }

  template <typename U, int w_x, int w_y>
  METAL_FUNC void load(const device U* src, const int ld) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load(
            frag_at(i, j),
            &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
            ld,
            Int<1>{});
      }
    }
  }

  template <typename U, int w_x, int w_y>
  METAL_FUNC void store(device U* dst, const int ld) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store(
            frag_at(i, j),
            &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
            ld,
            Int<1>{});
      }
    }
  }

  // src_tile_dims is (cols, rows) of valid data relative to `src`.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load_safe(
            frag_at(i, j),
            src,
            ld,
            Int<1>{},
            src_tile_dims.y,
            src_tile_dims.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }

  // dst_tile_dims is (cols, rows) of writable space relative to `dst`.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store_safe(
            frag_at(i, j),
            dst,
            ld,
            Int<1>{},
            dst_tile_dims.y,
            dst_tile_dims.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }

  // start/stop are (col, row) bounds relative to `dst`.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void store_slice(
      device U* dst,
      const int ld,
      const short2 start,
      const short2 stop) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store_slice(
            frag_at(i, j),
            dst,
            ld,
            Int<1>{},
            start.y,
            stop.y,
            start.x,
            stop.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }
};

// D = A * B + C over tiles of fragments (M x K times K x N).
// On odd rows m the n loop runs in reverse (n_serp), presumably to improve
// fragment reuse between consecutive rows.
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
    thread MMATile<T, M, N>& D,
    thread MMATile<U, M, K>& A,
    thread MMATile<U, K, N>& B,
    thread MMATile<T, M, N>& C) {
  STEEL_PRAGMA_UNROLL
  for (short m = 0; m < M; ++m) {
    STEEL_PRAGMA_UNROLL
    for (short n = 0; n < N; ++n) {
      short n_serp = (m % 2) ? (N - 1 - n) : n;
      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < K; ++k) {
        MMATile<T, M, N>::MMAFrag_t::mma(
            D.frag_at(m, n_serp),
            A.frag_at(m, k),
            B.frag_at(k, n_serp),
            C.frag_at(m, n_serp));
      }
    }
  }
}

// Identity epilogue for complex64_t outputs: passes the value through.
template <typename InT>
struct TransformNone<complex64_t, InT> {
  static METAL_FUNC complex64_t apply(complex64_t x) {
    return x;
  }
  static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {
    return x;
  }
};

// Threadgroup-level GEMM tile: accumulates a BM x BN block of A * B from
// BK-deep slices staged in threadgroup memory. Simdgroups form a WM x WN grid,
// and each owns TM x TN 8x8 accumulator fragments.
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType = float,
    typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
  // MMAFrag size
  STEEL_CONST short kFragSize = 8;
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = kFragSize * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = kFragSize * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / (kFragSize * WM);
  // Warp tile size along N
  STEEL_CONST short TN = BN / (kFragSize * WN);

  // Threadgroup A strides
  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K

  // Threadgroup B strides
  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N

  // Threadgroup strides along K
  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;

  // Simdgroup matrices
  MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
  MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;

  // Offsets within threadgroup
  short sm;
  short sn;

  short As_offset;
  short Bs_offset;

  /* Constructor */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
    // Determine thread position in simdgroup matrix
    short tm = kFragSize * (simd_group_id / WN);
    short tn = kFragSize * (simd_group_id % WN);

    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
    sm = simd_coord.y;
    sn = simd_coord.x;

    // Determine thread and simdgroup offset
    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N

    // From here on (sm, sn) is the lane's (row, col) inside the block.
    sm += tm;
    sn += tn;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;

    // Iterate over BK in blocks of kFragSize
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += kFragSize) {
      simdgroup_barrier(mem_flags::mem_none);

      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);

      simdgroup_barrier(mem_flags::mem_none);

      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);

      simdgroup_barrier(mem_flags::mem_none);

      tile_matmad(Ctile, Atile, Btile, Ctile);

      // Progress to next simdgroup tile
      As += tile_stride_a;
      Bs += tile_stride_b;
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device U* D, const int ldd) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;

    Ctile.template store<U, WM, WN>(D, ldd);
  }

  // Stores only the (col, row) window [start, stop) of the block.
  METAL_FUNC void
  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    D += sm * ldd + sn;
    start -= short2(sn, sm);
    stop -= short2(sn, sm);

    // TODO: Check the start as well
    if (stop.y <= 0 || stop.x <= 0) {
      return;
    }

    Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
  }

  // Bounds-checked store; dst_tile_dims is (cols, rows) of valid output.
  METAL_FUNC void
  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
  }

  /* Apply epilogue */
  template <typename UnaryEpilogue>
  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
    }
  }

  /* Apply epilogue */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue(
      const device U* C,
      const int ldc,
      const int fdc,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
          accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  /* Apply epilogue, reading C only inside dst_tile_dims (cols, rows).
   * Out-of-range positions are combined with zero instead of being read. */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue_safe(
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

        // Read C. Both the row and the column must be in range; previously
        // only the column was checked, so rows past the edge were read
        // out of bounds.
        U c_elems[kelems] = {0};

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          if ((i * TM_stride) < dst_tile_dims.y &&
              (j * TN_stride + k) < dst_tile_dims.x) {
            c_elems[k] = C[offset_c + k * fdc];
          }
        }

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  // Bounds-checked store that combines each output with the matching C entry.
  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = Ctile.frag_at(i, j);
          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              D[offset_d + k] =
                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
            }
          }
        }
      }
    }
  }
};

// complex64_t specialization. Real and imaginary parts are accumulated in two
// float tiles. Each K-chunk uses three real MMAs (Karatsuba / Gauss trick)
// instead of four.
template <
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType,
    typename Epilogue>
struct BlockMMA<
    complex64_t,
    U,
    BM,
    BN,
    BK,
    WM,
    WN,
    transpose_a,
    transpose_b,
    lda_tgp,
    ldb_tgp,
    AccumType,
    Epilogue> {
  static_assert(
      metal::is_same_v<AccumType, float>,
      "BlockMMA expects float accumulators");
  static_assert(
      metal::is_same_v<U, complex64_t>,
      "For complex BlockMMA, U must be complex64_t; use a different epilogue for projections");

  // MMAFrag size
  STEEL_CONST short kFragSize = 8;
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = kFragSize * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = kFragSize * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / (kFragSize * WM);
  // Warp tile size along N
  STEEL_CONST short TN = BN / (kFragSize * WN);

  // Threadgroup A strides
  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K

  // Threadgroup B strides
  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N

  // Threadgroup strides along K
  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;

  // When indexing complex as float[2]
  STEEL_CONST short A_str_m_f = A_str_m * 2;
  STEEL_CONST short A_str_k_f = A_str_k * 2;
  STEEL_CONST short B_str_k_f = B_str_k * 2;
  STEEL_CONST short B_str_n_f = B_str_n * 2;
  STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;
  STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;

  // Accumulators (real/imag)
  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;
  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;

  // Offsets within threadgroup
  short sm, sn;
  short As_offset, Bs_offset;

  /* Constructor */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
    // Determine thread position in simdgroup matrix
    short tm = kFragSize * (simd_group_id / WN);
    short tn = kFragSize * (simd_group_id % WN);

    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
    sm = simd_coord.y;
    sn = simd_coord.x;

    // Determine thread and simdgroup offset
    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)
    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)

    sm += tm;
    sn += tn;
  }

  /* Karatsuba MMA: 3 real MMAs per K-chunk */
  METAL_FUNC void mma(
      const threadgroup complex64_t* As,
      const threadgroup complex64_t* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;

    // View complex values as interleaved (real, imag) floats: real parts at
    // even float offsets, imaginary parts at odd ones.
    threadgroup const float* As_f =
        reinterpret_cast<threadgroup const float*>(As);
    threadgroup const float* Bs_f =
        reinterpret_cast<threadgroup const float*>(Bs);

    // Iterate over BK in blocks of kFragSize
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += kFragSize) {
      simdgroup_barrier(mem_flags::mem_none);

      MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;
      Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);
      Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);

      simdgroup_barrier(mem_flags::mem_none);

      MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;
      Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);
      Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);

      simdgroup_barrier(mem_flags::mem_none);

      // P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)
      // (fresh tiles start zeroed, so these are plain products)
      MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;

      tile_matmad(P, Ar, Br, P);
      tile_matmad(Q, Ai, Bi, Q);

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)
        Ar.elems()[i] += Ai.elems()[i];
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)
        Br.elems()[i] += Bi.elems()[i];

      tile_matmad(R, Ar, Br, R);

      // C_r += P - Q ; C_i += R - P - Q  (= Ar*Bi + Ai*Br)
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {
        const auto p = P.elems()[i];
        const auto q = Q.elems()[i];
        const auto r = R.elems()[i];
        Ctile_r.elems()[i] += (p - q);
        Ctile_i.elems()[i] += (r - p - q);
      }

      // Progress to next simdgroup tile
      As_f += tile_stride_a_f;
      Bs_f += tile_stride_b_f;
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device U* D, const int ldd) {
    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        thread const auto& r = Ctile_r.frag_at(i, j);
        thread const auto& im = Ctile_i.frag_at(i, j);
        int off = (i * TM_stride) * ldd + (j * TN_stride);

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
          D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
        }
      }
    }
  }

  // Stores only the (col, row) window [start, stop) of the block.
  METAL_FUNC void
  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
    D += sm * ldd + sn;
    start -= short2(sn, sm);
    stop -= short2(sn, sm);

    if (stop.y <= 0 || stop.x <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; ++i) {
      const int row = i * TM_stride;
      if (row >= start.y && row < stop.y) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < TN; ++j) {
          const int off = row * ldd + (j * TN_stride);
          thread const auto& r = Ctile_r.frag_at(i, j);
          thread const auto& im = Ctile_i.frag_at(i, j);

          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {
            const int col = j * TN_stride + k;
            if (col >= start.x && col < stop.x) {
              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
            }
          }
        }
      }
    }
  }

  // Bounds-checked store; dst_tile_dims is (cols, rows) of valid output.
  METAL_FUNC void
  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
    D += sm * ldd + sn;
    dst_tile_dims -= short2(sn, sm);
    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < TN; j++) {
          int off = (i * TM_stride) * ldd + (j * TN_stride);
          thread const auto& r = Ctile_r.frag_at(i, j);
          thread const auto& im = Ctile_i.frag_at(i, j);
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
            }
          }
        }
      }
    }
  }

  /* Apply epilogue */
  template <typename UnaryEpilogue>
  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {
      complex64_t out = epilogue_op.apply(
          complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));
      Ctile_r.elems()[i] = out.real;
      Ctile_i.elems()[i] = out.imag;
    }
  }

  /* Apply epilogue */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue(
      const device U* C,
      const int ldc,
      const int fdc,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in Cr, Ci
        thread auto& r = Ctile_r.frag_at(i, j);
        thread auto& im = Ctile_i.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
          complex64_t out = epilogue_op.apply(
              complex64_t(r[k], im[k]), C[offset_c + k * fdc]);
          r[k] = out.real;
          im[k] = out.imag;
        }
      }
    }
  }

  /* Apply epilogue, reading C only inside dst_tile_dims (cols, rows) */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue_safe(
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in Cr, Ci
        thread auto& r = Ctile_r.frag_at(i, j);
        thread auto& im = Ctile_i.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
        complex64_t tmp[kelems];

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          if ((j * TN_stride + k) < dst_tile_dims.x &&
              (i * TM_stride) < dst_tile_dims.y) {
            tmp[k] = C[offset_c + k * fdc];
          } else {
            tmp[k] = complex64_t(0.0f, 0.0f);
          }
        }

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);
          r[k] = out.real;
          im[k] = out.imag;
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;

    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in Cr, Ci
        thread const auto& r = Ctile_r.frag_at(i, j);
        thread const auto& im = Ctile_i.frag_at(i, j);
        int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int off_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          D[off_d + k] =
              epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);
        }
      }
    }
  }

  // Bounds-checked store that combines each output with the matching C entry.
  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in Cr, Ci
          thread const auto& r = Ctile_r.frag_at(i, j);
          thread const auto& im = Ctile_i.frag_at(i, j);
          int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int off_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              D[off_d + k] = epilogue_op.apply(
                  complex64_t(r[k], im[k]), C[off_c + k * fdc]);
            }
          }
        }
      }
    }
  }
};

} // namespace steel
} // namespace mlx
================================================
FILE: mlx/backend/metal/kernels/steel/gemm/nax.h
================================================
// Copyright © 2025 Apple Inc.
#pragma once #include #include #include #include "mlx/backend/metal/kernels/steel/defines.h" #include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" #include using namespace metal; /////////////////////////////////////////////////////////////////////////////// // MMA helper /////////////////////////////////////////////////////////////////////////////// namespace mlx { namespace steel { /////////////////////////////////////////////////////////////////////////////// // NAX Steel with new tiles /////////////////////////////////////////////////////////////////////////////// struct BaseNAXFrag { STEEL_CONST short kFragRows = 16; STEEL_CONST short kFragCols = 16; STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; STEEL_CONST short kElemRows = 2; STEEL_CONST short kElemCols = 4; STEEL_CONST short kElemRowsJump = 8; static_assert( kElemRows * kElemCols == kElemsPerFrag, "MMAFrag shape is not consistent with MMAFrag size"); template using dtype_frag_t = typename metal::vec; METAL_FUNC static short2 get_coord() { const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); const short qid = simd_lane_id >> 2; const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; return short2{fn, fm}; } METAL_FUNC static short2 get_coord(short idx) { const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); const short qid = simd_lane_id >> 2; const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; return short2{fn, fm}; } template < typename T, typename SrcPtrType, typename StrX, typename StrY, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void load( thread dtype_frag_t& dst, SrcPtrType src, StrX str_x, StrY str_y, OffX off_x = {}, OffY off_y = {}) { const short2 sc = get_coord(); src += sc.y * str_x + sc.x * str_y; STEEL_PRAGMA_UNROLL for (short i = 0; i < 
kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j) * str_y]); } } } } template < typename T, typename SrcPtrType, typename StrX, typename StrY, typename LimX, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void load_rows( thread dtype_frag_t& dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, OffX off_x = {}, OffY off_y = {}) { const short2 sc = get_coord(); src += sc.y * str_x + sc.x * str_y; auto lx = lim_x - sc.y; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; if (r < lx) { if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j) * str_y]); } } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[i * kElemCols + j] = T(0); } } } } template < typename T, typename SrcPtrType, typename StrX, typename StrY, typename LimX, typename LimY, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void load_safe( thread dtype_frag_t& dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x = {}, OffY off_y = {}) { const short2 sc = get_coord(); src += sc.y * str_x + sc.x * str_y; auto lx = lim_x - sc.y; auto ly = lim_y - sc.x; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { if ((r < lx) && ((c + j) < ly)) { dst[i * 
kElemCols + j] = static_cast(src[r * str_x + (c + j) * str_y]); } else { dst[i * kElemCols + j] = T(0); } } } } template < typename T, typename DstPtrType, typename StrX, typename StrY, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void store( const thread dtype_frag_t& src, DstPtrType dst, StrX str_x, StrY str_y, OffX off_x = {}, OffY off_y = {}) { using U = pointer_element_t; const short2 sc = get_coord(); dst += sc.y * str_x + sc.x * str_y; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[r * str_x + (c + j) * str_y] = static_cast(src[i * kElemCols + j]); } } } } template < typename T, typename DstPtrType, typename StrX, typename StrY, typename LimX, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void store_rows( const thread dtype_frag_t& src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, OffX off_x = {}, OffY off_y = {}) { using U = pointer_element_t; const short2 sc = get_coord(); dst += sc.y * str_x + sc.x * str_y; auto lx = lim_x - sc.y; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; if (r < lx) { if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); } } else { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { dst[r * str_x + (c + j) * str_y] = static_cast(src[i * kElemCols + j]); } } } } } template < typename T, typename DstPtrType, typename StrX, typename StrY, typename LimX, typename LimY, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void store_safe( const thread 
dtype_frag_t& src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x = {}, OffY off_y = {}) { using U = pointer_element_t; const short2 sc = get_coord(); dst += sc.y * str_x + sc.x * str_y; auto lx = lim_x - sc.y; auto ly = lim_y - sc.x; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { const auto r = off_x + i * kElemRowsJump; const auto c = off_y; STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { if (r < lx && (c + j) < ly) { dst[r * str_x + (c + j) * str_y] = static_cast(src[i * kElemCols + j]); } } } } template < typename T, typename DstPtrType, typename StrX, typename StrY, typename StartX, typename StopX, typename StartY, typename StopY, typename OffX = Int<0>, typename OffY = Int<0>> METAL_FUNC static constexpr void store_slice( const thread dtype_frag_t& src, DstPtrType dst, StrX str_x, StrY str_y, StartX start_x, StopX stop_x, StartY start_y, StopY stop_y, OffX off_x = Int<0>{}, OffY off_y = Int<0>{}) { using U = pointer_element_t; const short2 sc = get_coord(); const_for_loop<0, kElemRows, 1>([&](auto idx_row) { const auto r = off_x + idx_row * Int{}; if (r >= stop_x - sc.y || r < start_x - sc.y) { return; } const_for_loop<0, kElemCols, 1>([&](auto idx_col) { const auto c = off_y + idx_col; if (c >= stop_y - sc.x || c < start_y - sc.x) { return; } const auto src_idx = idx_row * Int{} + idx_col; dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = static_cast(src[src_idx]); }); }); } template METAL_FUNC static constexpr void row_reduce( thread const dtype_frag_t& inp_vals, thread T* reduced_vals) { STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { T thr_reduce = Op::apply( Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); qgr_reduce = Op::apply(thr_reduce, qgr_reduce); T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); sgr_reduce = Op::apply(qgr_reduce, 
sgr_reduce); reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); } } template METAL_FUNC static constexpr void row_bin_op( thread dtype_frag_t& inp_vals, thread T* row_vals) { STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { inp_vals[i * kElemCols + j] = Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); } } } template < typename CType, typename AType, typename BType, bool transpose_a = false, bool transpose_b = false> METAL_FUNC static constexpr void mma( thread dtype_frag_t& Cn0, thread dtype_frag_t& Cn1, const thread dtype_frag_t& A, metal::bool_constant, const thread dtype_frag_t& Bn0, const thread dtype_frag_t& Bn1, metal::bool_constant) { constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( 16, 32, 16, transpose_a, transpose_b, true, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); // Create matmul op mpp::tensor_ops::matmul2d gemm_op; // Create matmul operands in registers auto ct_a = gemm_op .template get_left_input_cooperative_tensor(); auto ct_b = gemm_op .template get_right_input_cooperative_tensor(); // Create matmul output in register auto ct_c = gemm_op.template get_destination_cooperative_tensor< decltype(ct_a), decltype(ct_b), CType>(); // Load A in to left operand registers STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_a[i] = A[i]; } // Load B into right operand registers STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_b[i] = Bn0[i]; ct_b[kElemsPerFrag + i] = Bn1[i]; } // Load C into output registers (op handles accumulation) STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_c[i] = Cn0[i]; ct_c[kElemsPerFrag + i] = Cn1[i]; } // Do matmul gemm_op.run(ct_a, ct_b, ct_c); // Copy out results STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { Cn0[i] = ct_c[i]; Cn1[i] = ct_c[kElemsPerFrag + i]; } } template < typename CType, typename AType, typename BType, bool transpose_a = false, 
bool transpose_b = false> METAL_FUNC static constexpr void mma( thread dtype_frag_t& Cm0, thread dtype_frag_t& Cm1, const thread dtype_frag_t& Am0, const thread dtype_frag_t& Am1, metal::bool_constant, const thread dtype_frag_t& B, metal::bool_constant) { // Create Matmul descriptor constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( 16, 32, 16, transpose_a, transpose_b, true, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); // Create matmul op mpp::tensor_ops::matmul2d gemm_op; // Create matmul operands in registers auto ct_a = gemm_op .template get_left_input_cooperative_tensor(); auto ct_b = gemm_op .template get_right_input_cooperative_tensor(); // Create matmul output in register auto ct_c = gemm_op.template get_destination_cooperative_tensor< decltype(ct_a), decltype(ct_b), CType>(); // Load A in to left operand registers STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_a[i] = Am0[i]; ct_a[kElemsPerFrag + i] = Am1[i]; } // Load B into right operand registers STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_b[i] = B[i]; } // Load C into output registers (op handles accumulation) STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { ct_c[i] = Cm0[i]; ct_c[kElemsPerFrag + i] = Cm1[i]; } // Do matmul gemm_op.run(ct_a, ct_b, ct_c); // Copy out results STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemsPerFrag; i++) { Cm0[i] = ct_c[i]; Cm1[i] = ct_c[kElemsPerFrag + i]; } } }; template < typename T, short kTileRows_, short kTileCols_, class NAXFrag_ = BaseNAXFrag> struct NAXTile { using NAXFrag_t = NAXFrag_; using elem_type = T; STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; STEEL_CONST short kTileRows = kTileRows_; STEEL_CONST short kTileCols = kTileCols_; STEEL_CONST short kRows = kTileRows * kFragRows; STEEL_CONST short kCols = kTileCols * kFragCols; STEEL_CONST short kNumFrags = 
kTileRows * kTileCols; STEEL_CONST short kElemsPerTile = kNumFrags * kElemsPerFrag; STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; STEEL_CONST short kRowsPerThread = kTileRows * NAXFrag_t::kElemRows; STEEL_CONST short kColsPerThread = kTileCols * NAXFrag_t::kElemCols; typedef typename NAXFrag_t::template dtype_frag_t frag_type; frag_type val_frags[kNumFrags]; // = {frag_type(0)}; METAL_FUNC NAXTile() thread {} METAL_FUNC constexpr void clear() { STEEL_PRAGMA_UNROLL for (short i = 0; i < kNumFrags; ++i) { val_frags[i] = frag_type(0); } } METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { return val_frags[i * kTileCols + j]; } METAL_FUNC constexpr const thread frag_type& frag_at( const short i, const short j) const { return val_frags[i * kTileCols + j]; } template METAL_FUNC constexpr thread frag_type& frag_at() { return val_frags[i * kTileCols + j]; } template METAL_FUNC constexpr const thread frag_type& frag_at() const { return val_frags[i * kTileCols + j]; } template METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j, metal::bool_constant) { if constexpr (transpose) { return frag_at(j, i); } else { return frag_at(i, j); } } template METAL_FUNC constexpr const thread frag_type& frag_at(const short i, const short j, metal::bool_constant) const { if constexpr (transpose) { return frag_at(j, i); } else { return frag_at(i, j); } } template METAL_FUNC constexpr thread frag_type& frag_at() { if constexpr (transpose) { return frag_at(); } else { return frag_at(); } } template METAL_FUNC constexpr const thread frag_type& frag_at() const { if constexpr (transpose) { return frag_at(); } else { return frag_at(); } } METAL_FUNC thread elem_type* elems() { return reinterpret_cast(val_frags); } METAL_FUNC const thread elem_type* elems() const { return reinterpret_cast(val_frags); } template METAL_FUNC void 
row_reduce(thread metal::vec& vals) const { auto vptr = (thread T*)(&vals); STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { NAXFrag_t::template row_reduce( frag_at(i, j), &vptr[i * kFragThrRows]); } } } template METAL_FUNC void row_bin_op(thread metal::vec& vals) { auto vptr = (thread T*)(&vals); STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { NAXFrag_t::template row_bin_op( frag_at(i, j), &vptr[i * kFragThrRows]); } } } template METAL_FUNC void load(const threadgroup U* src) { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::load( frag_at(), src, Int{}, Int{}, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void store(threadgroup U* dst) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::store( frag_at(), dst, Int{}, Int{}, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void load(const device U* src, const int ld) { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::load( frag_at(), src, ld, Int<1>{}, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void store(device U* dst, const int ld) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::store( frag_at(), dst, ld, Int<1>{}, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void load_rows(const device U* src, const int ld, const short n_rows) { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::load_rows( frag_at(), src, ld, Int<1>{}, n_rows, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void load_safe(const device U* src, const int ld, const short2 
src_tile_dims) { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::load_safe( frag_at(), src, ld, Int<1>{}, src_tile_dims.y, src_tile_dims.x, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::store_rows( frag_at(), dst, ld, Int<1>{}, n_rows, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::store_safe( frag_at(), dst, ld, Int<1>{}, dst_tile_dims.y, dst_tile_dims.x, idx_row * Int{}, idx_col * Int{}); }); }); } template METAL_FUNC void store_slice( device U* dst, const int ld, const short2 start, const short2 stop) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { NAXFrag_t::store_slice( frag_at(), dst, ld, Int<1>{}, start.y, stop.y, start.x, stop.x, idx_row * Int{}, idx_col * Int{}); }); }); } }; template < class CTile, class ATile, class BTile, bool transpose_a, bool transpose_b> METAL_FUNC void tile_matmad_nax( thread CTile& C, thread ATile& A, metal::bool_constant, thread BTile& B, metal::bool_constant) { // Static checks constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; constexpr short TM = CTile::kTileRows; static_assert(TMa == TM, "MXU tile matmul: M dimensions do not match"); constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; constexpr short TN = CTile::kTileCols; static_assert(TNb == TN, "MXU tile matmul: N dimensions do not match"); constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; constexpr short TK = transpose_b ? 
BTile::kTileCols : BTile::kTileRows; static_assert(TKa == TK, "MXU tile matmul: K dimensions do not match"); constexpr auto ta = metal::bool_constant{}; constexpr auto tb = metal::bool_constant{}; if constexpr (TN == 1 && TM % 2 == 0) { STEEL_PRAGMA_UNROLL for (short mm = 0; mm < TM; mm += 2) { STEEL_PRAGMA_UNROLL for (short nn = 0; nn < TN; ++nn) { STEEL_PRAGMA_UNROLL for (short kk = 0; kk < TK; ++kk) { CTile::NAXFrag_t::mma( C.frag_at(mm, nn), C.frag_at(mm + 1, nn), A.frag_at(mm, kk, ta), A.frag_at(mm + 1, kk, ta), metal::bool_constant{}, B.frag_at(kk, nn, tb), metal::bool_constant{}); } } } } else if constexpr (TN % 2 == 0) { STEEL_PRAGMA_UNROLL for (short mm = 0; mm < TM; ++mm) { STEEL_PRAGMA_UNROLL for (short nn = 0; nn < TN; nn += 2) { STEEL_PRAGMA_UNROLL for (short kk = 0; kk < TK; ++kk) { CTile::NAXFrag_t::mma( C.frag_at(mm, nn), C.frag_at(mm, nn + 1), A.frag_at(mm, kk, ta), metal::bool_constant{}, B.frag_at(kk, nn, tb), B.frag_at(kk, nn + 1, tb), metal::bool_constant{}); } } } } } } // namespace steel } // namespace mlx ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/params.h ================================================ // Copyright © 2024 Apple Inc. 
#pragma once

///////////////////////////////////////////////////////////////////////////////
// GEMM param classes
///////////////////////////////////////////////////////////////////////////////

// NOTE(review): these structs look like plain argument blocks shared with the
// host, so member order and member types are part of their binary layout.
// Grouped declarators below keep exactly the original order; do not reorder
// or retype members without updating whoever fills them.

namespace mlx {
namespace steel {

// Scalar parameters of a (possibly batched) GEMM launch.
struct GEMMParams {
  // Problem extents.
  const int M, N, K;

  // Leading dimensions of A, B and D (BLAS-style naming).
  const int lda, ldb, ldd;

  // Output tile counts along N and M.
  const int tiles_n, tiles_m;

  // Per-batch strides of A, B and D.
  const int64_t batch_stride_a, batch_stride_b, batch_stride_d;

  // Presumably consumed by BlockSwizzle::swizzle (same parameter name).
  const int swizzle_log;
  const int gemm_k_iterations_aligned;

  const int batch_ndim;
};

// Scalar parameters of a split-K GEMM launch.
// NOTE(review): "Spilt" is a misspelling of "Split", kept because other code
// presumably refers to the type by this name.
struct GEMMSpiltKParams {
  const int M, N, K;

  const int lda, ldb, ldc;

  const int tiles_n, tiles_m;

  const int split_k_partitions, split_k_partition_stride,
      split_k_partition_size;

  const int swizzle_log;
  const int gemm_k_iterations_aligned;
};

// Extra parameters for an addmm-style epilogue over a C operand: row stride
// (ldc), column stride (fdc), batch stride of C, and the alpha / beta scales
// (compare TransformAxpby).
struct GEMMAddMMParams {
  const int ldc, fdc;
  const int64_t batch_stride_c;

  const float alpha, beta;
};

} // namespace steel
} // namespace mlx ================================================ FILE: mlx/backend/metal/kernels/steel/gemm/transforms.h ================================================ // Copyright © 2024 Apple Inc.
#pragma once #include "mlx/backend/metal/kernels/steel/utils.h" /////////////////////////////////////////////////////////////////////////////// // Transforms and Epilogues /////////////////////////////////////////////////////////////////////////////// namespace mlx { namespace steel { template struct TransformNone { static METAL_FUNC OutT apply(InT x) { return static_cast(x); } static METAL_FUNC OutT apply(InT x, OutT) { return static_cast(x); } }; template struct TransformAdd { TransformAdd(const float, const float) {} static METAL_FUNC OutT apply(InT x) { return static_cast(x); } static METAL_FUNC OutT apply(InT x, OutT c) { return static_cast(x) + c; } }; template struct TransformAxpby { const float alpha; const float beta; TransformAxpby(const float alpha_, const float beta_) : alpha(alpha_), beta(beta_) {} static METAL_FUNC OutT apply(InT x) { return static_cast(x); } METAL_FUNC OutT apply(InT x, OutT c) const { return static_cast( x * static_cast(alpha) + (static_cast(beta) * c)); } }; template struct AccumHelper { typedef float accum_type; }; struct BlockSwizzle { static METAL_FUNC int2 swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { const int tid_x = (tid.x) >> swizzle_log; const int tid_y = ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); return int2(tid_x, tid_y); } }; } // namespace steel } // namespace mlx ================================================ FILE: mlx/backend/metal/kernels/steel/utils/integral_constant.h ================================================ // Copyright © 2024 Apple Inc. 
#pragma once

// NOTE(review): include names and most template parameter / argument lists
// in this file appear truncated in this copy; tokens are reproduced
// unchanged.

#include
#include "mlx/backend/metal/kernels/steel/utils/type_traits.h"

#pragma METAL internals : enable

namespace mlx {
namespace steel {

///////////////////////////////////////////////////////////////////////////////
// Integral constant with casting
///////////////////////////////////////////////////////////////////////////////

// Compile-time constant `v` of type T that converts implicitly to T.
template
struct integral_constant {
  static constexpr constant T value = v;
  using value_type = T;
  using type = integral_constant;

  METAL_FUNC constexpr operator value_type() const noexcept {
    return value;
  }

  // METAL_FUNC constexpr value_type operator()() const noexcept {
  //   return value;
  // }
};

template
using bool_constant = integral_constant;
using true_type = bool_constant;
using false_type = bool_constant;

// Integral-type trait. The partial specialization presumably also reports
// integral_constant wrappers of integral types as integral (its arguments
// are truncated in this copy).
template
struct is_integral : bool_constant::value> {};

template
struct is_integral>
    : bool_constant::value> {};

template
constexpr constant bool is_integral_v = is_integral::value;

// Shorthand for integral constants (used e.g. as Int<0>, Int<1> in nax.h).
template
using Int = integral_constant;

///////////////////////////////////////////////////////////////////////////////
// Binary Operators on Integral constants
///////////////////////////////////////////////////////////////////////////////

// Defines __operator__ on two integral_constants: the result `res` is
// computed at compile time and returned as an integral_constant, so
// arithmetic and comparisons on constant indices stay compile-time values.
#define integral_const_binop(__op__, __operator__) \
  template \
  METAL_FUNC constexpr auto __operator__( \
      integral_constant, integral_constant) { \
    constexpr auto res = tv __op__ uv; \
    return integral_constant{}; \
  }

integral_const_binop(+, operator+);
integral_const_binop(-, operator-);
integral_const_binop(*, operator*);
integral_const_binop(/, operator/);

integral_const_binop(==, operator==);
integral_const_binop(!=, operator!=);
integral_const_binop(<, operator<);
integral_const_binop(>, operator>);
integral_const_binop(<=, operator<=);
integral_const_binop(>=, operator>=);

integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);

// Short-circuit overloads: true_type || x and false_type && x fold to a
// compile-time constant regardless of the other operand (which, judging by
// the surviving `>>`, is constrained via an enable_if that is truncated here).
template >>
METAL_FUNC constexpr auto operator||(true_type, T) {
  return true_type{};
}

template >>
METAL_FUNC constexpr auto operator||(T, true_type) {
  return true_type{};
}

template >>
METAL_FUNC constexpr auto operator&&(false_type, T) {
  return false_type{};
}

template >>
METAL_FUNC constexpr auto operator&&(T, false_type) {
  return false_type{};
}

// Dispatch utilities

// Converts a runtime bool into a compile-time true_type / false_type
// argument for f.
template
void dispatch_bool(bool v, F f) {
  if (v) {
    f(true_type{});
  } else {
    f(false_type{});
  }
}

// Compile-time loop: while start < stop, calls f with a constant index and
// recurses. Used as const_for_loop<start, stop, step>(f) in nax.h; the
// recursive call's template arguments are truncated in this copy
// (presumably <start + step, stop, step>).
template
constexpr void const_for_loop(F f) {
  if constexpr (start < stop) {
    constexpr auto idx = Int{};
    f(idx);
    const_for_loop(f);
  }
}

#undef integral_const_binop

///////////////////////////////////////////////////////////////////////////////
// Reduction operators
///////////////////////////////////////////////////////////////////////////////

// Variadic sum: sum(x) == x, sum(x, us...) == x + sum(us...).
template
METAL_FUNC constexpr T sum(T x) {
  return x;
}

template
METAL_FUNC constexpr auto sum(T x, Us... us) {
  return x + sum(us...);
}

} // namespace steel
} // namespace mlx

#pragma METAL internals : disable

================================================
FILE: mlx/backend/metal/kernels/steel/utils/type_traits.h
================================================
// Copyright © 2024 Apple Inc.
#pragma once #include #pragma METAL internals : enable namespace metal { template struct is_empty : metal::bool_constant<__is_empty(T)> {}; #ifdef __cpp_variable_templates template constexpr constant bool is_empty_v = is_empty::value; #endif template struct make_void { typedef void type; }; template using void_t = typename make_void::type; template struct is_static : metal::bool_constant>::value> {}; template struct pointer_element {}; template struct pointer_element { using type = remove_cv_t; }; template struct pointer_element { using type = remove_cv_t; }; template struct pointer_element { using type = remove_cv_t; }; template struct pointer_element { using type = remove_cv_t; }; template using pointer_element_t = typename pointer_element>::type; } // namespace metal #pragma METAL internals : disable ================================================ FILE: mlx/backend/metal/kernels/steel/utils.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include METAL_FUNC ulong2 elem_to_loc_broadcast( uint elem, constant const int* shape, constant const int64_t* a_strides, constant const int64_t* b_strides, int ndim) { ulong loc_a{0}; ulong loc_b{0}; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { int pos_in_dim = (elem % shape[i]); elem /= shape[i]; loc_a += pos_in_dim * a_strides[i]; loc_b += pos_in_dim * b_strides[i]; } return ulong2(loc_a, loc_b); } METAL_FUNC ulong3 elem_to_loc_broadcast( uint elem, constant const int* shape, constant const int64_t* a_strides, constant const int64_t* b_strides, constant const int64_t* c_strides, int ndim) { ulong loc_a{0}; ulong loc_b{0}; ulong loc_c{0}; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { int pos_in_dim = (elem % shape[i]); elem /= shape[i]; loc_a += pos_in_dim * a_strides[i]; loc_b += pos_in_dim * b_strides[i]; loc_c += pos_in_dim * c_strides[i]; } return ulong3(loc_a, loc_b, loc_c); } ================================================ FILE: 
mlx/backend/metal/kernels/ternary.h
================================================
// Copyright © 2024 Apple Inc.

// NOTE(review): template parameter lists and helper template arguments
// (e.g. `WorkPerThread::n`, a bare `template`) appear truncated in this copy;
// tokens are reproduced unchanged.

// Contiguous select-style kernel: thread `index` computes the N consecutive
// outputs starting at index * N. BSCALAR / CSCALAR make b / c scalars that
// are always read at element 0. When N > 1, the thread straddling `size`
// only processes the in-bounds tail.
template <
    typename T,
    typename Op,
    bool BSCALAR,
    bool CSCALAR,
    int N = WorkPerThread::n>
[[kernel]] void ternary_v(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  index *= N;
  if (N > 1 && index + N > size) {
    for (int i = 0; index + i < size; ++i) {
      auto bidx = BSCALAR ? 0 : index + i;
      auto cidx = CSCALAR ? 0 : index + i;
      d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
    }
  } else {
    for (int i = 0; i < N; ++i) {
      auto bidx = BSCALAR ? 0 : index + i;
      auto cidx = CSCALAR ? 0 : index + i;
      d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
    }
  }
}

// Same as ternary_v, for a 2-D grid: the linear offset is computed in 64-bit
// arithmetic and `size` is 64-bit.
template <
    typename T,
    typename Op,
    bool BSCALAR,
    bool CSCALAR,
    int N = WorkPerThread::n>
[[kernel]] void ternary_v2(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    for (int i = 0; offset + i < size; ++i) {
      auto bidx = BSCALAR ? 0 : offset + i;
      auto cidx = CSCALAR ? 0 : offset + i;
      d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
    }
  } else {
    for (int i = 0; i < N; ++i) {
      auto bidx = BSCALAR ? 0 : offset + i;
      auto cidx = CSCALAR ? 0 : offset + i;
      d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
    }
  }
}

// Strided 1-D kernel: inputs are addressed through their own strides, the
// output is written contiguously at `index`.
template
[[kernel]] void ternary_g_nd1(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int64_t& a_strides,
    constant const int64_t& b_strides,
    constant const int64_t& c_strides,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_strides);
  auto b_idx = elem_to_loc_1(index, b_strides);
  auto c_idx = elem_to_loc_1(index, c_strides);
  d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

// Strided 2-D kernel; the output index is row-major over the grid.
template
[[kernel]] void ternary_g_nd2(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int64_t a_strides[2],
    constant const int64_t b_strides[2],
    constant const int64_t c_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  auto c_idx = elem_to_loc_2(index, c_strides);
  IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

// Strided 3-D kernel; the output index is row-major over the grid.
template
[[kernel]] void ternary_g_nd3(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int64_t a_strides[3],
    constant const int64_t b_strides[3],
    constant const int64_t c_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  auto c_idx = elem_to_loc_3(index, c_strides);
  IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

// General N-d strided kernel: each thread handles up to N consecutive
// elements along the last axis, starting at N * index.x, stepping each
// input by its last-axis stride and stopping at the end of that axis.
template
[[kernel]] void ternary_g(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_3_nd(
      {N * index.x, index.y, index.z},
      shape,
      a_strides,
      b_strides,
      c_strides,
      ndim);
  auto xshape = shape[ndim - 1];
  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
  IdxT a_xstride = a_strides[ndim - 1];
  IdxT b_xstride = b_strides[ndim - 1];
  IdxT c_xstride = c_strides[ndim - 1];
  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
    d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
    idx.x += a_xstride;
    idx.y += b_xstride;
    idx.z += c_xstride;
  }
}

================================================
FILE: mlx/backend/metal/kernels/ternary.metal
================================================
// Copyright © 2024 Apple Inc.

#include
#include

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"

// Per (op, type): contiguous kernels with N = 1 (v / vs / sv) and their 2-D
// grid variants (v2 / vs2 / sv2), the general kernel with N = 2 and int
// indices (gn2), fixed-rank strided kernels with int indices (g1 / g2 / g3),
// and "large" variants that omit the explicit index type (g1large /
// g2large / g3large, gn4large with N = 4).
#define instantiate_ternary_base(op, tname, type) \
  instantiate_kernel("v_" #op #tname, ternary_v, type, op, false, false, 1) \
  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op, false, false) \
  instantiate_kernel("vs_" #op #tname, ternary_v, type, op, false, true, 1) \
  instantiate_kernel("vs2_" #op #tname, ternary_v2, type, op, false, true) \
  instantiate_kernel("sv_" #op #tname, ternary_v, type, op, true, false, 1) \
  instantiate_kernel("sv2_" #op #tname, ternary_v2, type, op, true, false) \
  instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
  instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
  instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
  instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \
  instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
  instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
  instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
  instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \

// Base set plus contiguous kernels using the default work-per-thread N
// (vn / vsn / svn).
#define instantiate_ternary_all(op, tname, type) \
  instantiate_kernel("vn_" #op #tname, ternary_v, type, op, false, false) \
  instantiate_kernel("vsn_" #op #tname, ternary_v, type, op, false, true) \
  instantiate_kernel("svn_" #op #tname, ternary_v, type, op, true, false) \
  instantiate_ternary_base(op, tname, type)

// All supported element types; 64-bit integers and complex64 only get the
// base set (no vn / vsn / svn kernels).
#define instantiate_ternary_types(op) \
  instantiate_ternary_all(op, bool_, bool) \
  instantiate_ternary_all(op, uint8, uint8_t) \
  instantiate_ternary_all(op, uint16, uint16_t) \
  instantiate_ternary_all(op, uint32, uint32_t) \
  instantiate_ternary_base(op, uint64, uint64_t) \
  instantiate_ternary_all(op, int8, int8_t) \
  instantiate_ternary_all(op, int16, int16_t) \
  instantiate_ternary_all(op, int32, int32_t) \
  instantiate_ternary_base(op, int64, int64_t) \
  instantiate_ternary_all(op, float16, half) \
  instantiate_ternary_all(op, float32, float) \
  instantiate_ternary_all(op, bfloat16, bfloat16_t) \
  instantiate_ternary_base(op, complex64, complex64_t) // clang-format on

instantiate_ternary_types(Select)

================================================
FILE: mlx/backend/metal/kernels/ternary_ops.h
================================================
// Copyright © 2023-2024 Apple Inc.

#pragma once

// Element-wise select: returns x where condition is true, else y.
struct Select {
  template
  T operator()(bool condition, T x, T y) {
    return condition ? x : y;
  }
};

================================================
FILE: mlx/backend/metal/kernels/unary.h
================================================
// Copyright © 2024 Apple Inc.
// Contiguous unary kernel: thread `index` handles elements
// [index * N, index * N + N) of `in`, writing Op(in[i]) converted to the
// output type into out[i].
// NOTE(review): template parameter lists and cast/template arguments (e.g.
// the text before "::n>" and after static_cast) appear to be missing from
// this text; confirm against the upstream header before editing.
template ::n> [[kernel]] void unary_v(
    device const T* in,
    device U* out,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  index *= N;
  // Last partial chunk: only touch elements below `size`.
  if (N > 1 && index + N > size) {
    for (int i = 0; index + i < size; ++i) {
      out[index + i] = static_cast(Op()(in[index + i]));
    }
  } else {
    for (int i = 0; i < N; ++i) {
      out[index + i] = static_cast(Op()(in[index + i]));
    }
  }
}

// Contiguous unary kernel over a 2D grid with an int64_t element count.
// The chunk start is N * (index.x + grid_dim.x * index.y).
template ::n> [[kernel]] void unary_v2(
    device const T* in,
    device U* out,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // index.y is widened to int64_t first, so the linear offset is formed in
  // 64-bit arithmetic.
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  // Last partial chunk: only touch elements below `size`.
  if (N > 1 && offset + N > size) {
    for (int i = 0; offset + i < size; ++i) {
      out[offset + i] = static_cast(Op()(in[offset + i]));
    }
  } else {
    for (int i = 0; i < N; ++i) {
      out[offset + i] = static_cast(Op()(in[offset + i]));
    }
  }
}

// General (strided) unary kernel. index.x selects a chunk of N elements along
// the innermost dimension, index.y the second-innermost dimension and index.z
// the remaining leading dimensions (matching the uint3 overload of
// elem_to_loc). Results are written to consecutive output slots from out_idx.
// NOTE(review): `ndim` is read from the `device` address space here while
// ternary_g reads it as `constant` — confirm the difference is intended.
template <
    typename T,
    typename U,
    typename Op,
    int N = 1,
    typename IdxT = int64_t>
[[kernel]] void unary_g(
    device const T* in,
    device U* out,
    constant const int* in_shape,
    constant const int64_t* in_strides,
    device const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc(
      {N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
  auto xshape = in_shape[ndim - 1];
  IdxT xstride = in_strides[ndim - 1];
  // Dense output position; presumably the host sets grid_dim.y to the
  // second-innermost extent (dispatch code not visible here).
  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
  // The final chunk of a row stops at the end of the innermost dimension.
  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
    out[out_idx++] = static_cast(Op()(in[idx]));
    idx += xstride;
  }
}
================================================
FILE: mlx/backend/metal/kernels/unary.metal
================================================
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"

// Host names are built by stringizing and concatenating the parts, e.g.
// "v_" #op #in_tname #out_tname with (Abs, float16, float16) yields
// "v_Absfloat16float16". Trailing instantiate_kernel arguments are presumably
// forwarded as the kernel's template arguments, in order.

// "vn_": unary_v without an explicit N, i.e. its template default.
#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \
  instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op)

// Base kernel set:
//   "v_"        unary_v with N = 1
//   "v2_"       unary_v2 with its default N
//   "gn1_"      unary_g with N = 1 and int indexing
//   "gn4large_" unary_g with N = 4 and its default IdxT (int64_t)
#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
  instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \
  instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
  instantiate_kernel( \
      "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
  instantiate_kernel( \
      "gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)

// Base set plus the "vn_" work-per-thread kernel.
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
  instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
  instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type)

// Shorthands for ops whose input and output types are the same.
#define instantiate_unary_all_same(op, tname, type) \
  instantiate_unary_all(op, tname, tname, type, type)

#define instantiate_unary_base_same(op, tname, type) \
  instantiate_unary_base(op, tname, tname, type, type)

// float16, float32 and bfloat16.
#define instantiate_unary_float(op) \
  instantiate_unary_all_same(op, float16, half) \
  instantiate_unary_all_same(op, float32, float) \
  instantiate_unary_all_same(op, bfloat16, bfloat16_t)

// 8/16/32-bit integers get the full set; 64-bit integers only the base set.
#define instantiate_unary_int(op) \
  instantiate_unary_all_same(op, uint8, uint8_t) \
  instantiate_unary_all_same(op, uint16, uint16_t) \
  instantiate_unary_all_same(op, uint32, uint32_t) \
  instantiate_unary_base_same(op, uint64, uint64_t) \
  instantiate_unary_all_same(op, int8, int8_t) \
  instantiate_unary_all_same(op, int16, int16_t) \
  instantiate_unary_all_same(op, int32, int32_t) \
  instantiate_unary_base_same(op, int64, int64_t)

// bool, every integer type and every float type.
#define instantiate_unary_types(op) \
  instantiate_unary_all_same(op, bool_, bool) \
  instantiate_unary_int(op) \
  instantiate_unary_float(op)

instantiate_unary_types(Abs)
// Real-valued ops. instantiate_unary_float covers float16/float32/bfloat16;
// instantiate_unary_types additionally covers bool and the integer types.
instantiate_unary_float(ArcCos)
instantiate_unary_float(ArcCosh)
instantiate_unary_float(ArcSin)
instantiate_unary_float(ArcSinh)
instantiate_unary_float(ArcTan)
instantiate_unary_float(ArcTanh)
instantiate_unary_types(Ceil)
instantiate_unary_float(Cos)
instantiate_unary_float(Cosh)
instantiate_unary_float(Exp)
instantiate_unary_float(Expm1)
instantiate_unary_types(Floor)
instantiate_unary_float(Log)
instantiate_unary_float(Log2)
instantiate_unary_float(Log10)
instantiate_unary_float(Log1p)
instantiate_unary_types(Negative)
instantiate_unary_float(Sigmoid)
instantiate_unary_float(Erf)
instantiate_unary_float(ErfInv)
instantiate_unary_types(Sign)
instantiate_unary_float(Sin)
instantiate_unary_float(Sinh)
instantiate_unary_types(Square)
instantiate_unary_float(Sqrt)
instantiate_unary_float(Rsqrt)
instantiate_unary_float(Tan)
instantiate_unary_float(Tanh)
instantiate_unary_float(Round)

// Integer types only.
instantiate_unary_int(BitwiseInvert)

// complex64 kernels: base set only (no "vn_" work-per-thread kernel).
instantiate_unary_base_same(Abs, complex64, complex64_t)
instantiate_unary_base_same(ArcCos, complex64, complex64_t)
instantiate_unary_base_same(ArcSin, complex64, complex64_t)
instantiate_unary_base_same(ArcTan, complex64, complex64_t)
instantiate_unary_base_same(Conjugate, complex64, complex64_t)
instantiate_unary_base_same(Cos, complex64, complex64_t)
instantiate_unary_base_same(Cosh, complex64, complex64_t)
instantiate_unary_base_same(Exp, complex64, complex64_t)
instantiate_unary_base_same(Log, complex64, complex64_t)
instantiate_unary_base_same(Log1p, complex64, complex64_t)
instantiate_unary_base_same(Log2, complex64, complex64_t)
instantiate_unary_base_same(Log10, complex64, complex64_t)
instantiate_unary_base_same(Negative, complex64, complex64_t)
instantiate_unary_base_same(Sign, complex64, complex64_t)
instantiate_unary_base_same(Sin, complex64, complex64_t)
instantiate_unary_base_same(Sinh, complex64, complex64_t)
instantiate_unary_base_same(Square, complex64, complex64_t)
instantiate_unary_base_same(Sqrt, complex64, complex64_t)
instantiate_unary_base_same(Rsqrt, complex64, complex64_t)
instantiate_unary_base_same(Tan, complex64, complex64_t)
instantiate_unary_base_same(Tanh, complex64, complex64_t)
instantiate_unary_base_same(Round, complex64, complex64_t)

// complex64 input, float32 output.
instantiate_unary_base(Real, complex64, float32, complex64_t, float)
instantiate_unary_base(Imag, complex64, float32, complex64_t, float)

// bool only.
instantiate_unary_all_same(LogicalNot, bool_, bool)

// FP8 conversions between float types and uint8 bit patterns.
instantiate_unary_all(ToFP8, float16, uint8, float16_t, uint8_t)
instantiate_unary_all(ToFP8, bfloat16, uint8, bfloat16_t, uint8_t)
instantiate_unary_all(ToFP8, float32, uint8, float, uint8_t)
instantiate_unary_all(FromFP8, uint8, float16, uint8_t, float16_t)
instantiate_unary_all(FromFP8, uint8, bfloat16, uint8_t, bfloat16_t)
instantiate_unary_all(FromFP8, uint8, float32, uint8_t, float)
// clang-format on
================================================
FILE: mlx/backend/metal/kernels/unary_ops.h
================================================
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include
#include
#include "mlx/backend/metal/kernels/cexpf.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/expm1f.h"
#include "mlx/backend/metal/kernels/fp8.h"

// Elementwise functors used by the unary kernels. Each struct has a generic
// `template` operator() plus explicit overloads where integer, bool or
// complex64_t inputs need different handling.
// NOTE(review): template parameter lists, template arguments and the names in
// the two bare #include lines appear to be missing from this text (e.g.
// `template T operator()`); confirm against the upstream header.

namespace {
constant float inf = metal::numeric_limits::infinity();
}

// Unsigned integers and bool are returned unchanged. For complex input the
// modulus is returned in the real part, with a zero imaginary part.
struct Abs {
  template T operator()(T x) { return metal::abs(x); };
  uint8_t operator()(uint8_t x) { return x; };
  uint16_t operator()(uint16_t x) { return x; };
  uint32_t operator()(uint32_t x) { return x; };
  uint64_t operator()(uint64_t x) { return x; };
  bool operator()(bool x) { return x; };
  complex64_t operator()(complex64_t x) {
    return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
  };
};

// The complex64_t overloads of ArcCos, ArcSin and ArcTan are only declared
// here; their definitions follow the Tanh functor.
struct ArcCos {
  template T operator()(T x) { return metal::precise::acos(x); };
  complex64_t operator()(complex64_t x);
};

struct ArcCosh {
  template T operator()(T x) { return metal::precise::acosh(x); };
};

struct ArcSin {
  template T operator()(T x) { return metal::precise::asin(x); };
  complex64_t operator()(complex64_t x);
};

struct ArcSinh {
  template T operator()(T x) { return metal::precise::asinh(x); };
};

struct ArcTan {
  template T operator()(T x) { return metal::precise::atan(x); };
  complex64_t operator()(complex64_t x);
};

struct ArcTanh {
  template T operator()(T x) { return metal::precise::atanh(x); };
};

struct BitwiseInvert {
  template T operator()(T x) { return ~x; };
};

// Integer and bool inputs pass through unchanged.
struct Ceil {
  template T operator()(T x) { return metal::ceil(x); };
  int8_t operator()(int8_t x) { return x; };
  int16_t operator()(int16_t x) { return x; };
  int32_t operator()(int32_t x) { return x; };
  int64_t operator()(int64_t x) { return x; };
  uint8_t operator()(uint8_t x) { return x; };
  uint16_t operator()(uint16_t x) { return x; };
  uint32_t operator()(uint32_t x) { return x; };
  uint64_t operator()(uint64_t x) { return x; };
  bool operator()(bool x) { return x; };
};

// complex: cos(a + ib) = cos(a)cosh(b) - i sin(a)sinh(b).
struct Cos {
  template T operator()(T x) { return metal::precise::cos(x); };
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
        -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)};
  };
};

// complex: cosh(a + ib) = cosh(a)cos(b) + i sinh(a)sin(b).
struct Cosh {
  template T operator()(T x) { return metal::precise::cosh(x); };
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
        metal::precise::sinh(x.real) * metal::precise::sin(x.imag)};
  };
};

struct Conjugate {
  complex64_t operator()(complex64_t x) {
    return complex64_t{x.real, -x.imag};
  }
};

// Erf, ErfInv and Expm1 go through an intermediate static_cast whose target
// type is not visible here (presumably float, matching the float helpers).
struct Erf {
  template T operator()(T x) {
    return static_cast(erf(static_cast(x)));
  };
};

struct ErfInv {
  template T operator()(T x) {
    return static_cast(erfinv(static_cast(x)));
  };
};

// Complex input is delegated to cexpf.
struct Exp {
  template T operator()(T x) { return metal::precise::exp(x); };
  complex64_t operator()(complex64_t x) {
    return cexpf(x);
  }
};

struct Expm1 {
  template T operator()(T x) {
    return static_cast(expm1f(static_cast(x)));
  };
};

// Integer and bool inputs pass through unchanged.
struct Floor {
  template T operator()(T x) { return metal::floor(x); };
  int8_t operator()(int8_t x) { return x; };
  int16_t operator()(int16_t x) { return x; };
  int32_t operator()(int32_t x) { return x; };
  int64_t operator()(int64_t x) { return x; };
  uint8_t operator()(uint8_t x) { return x; };
  uint16_t operator()(uint16_t x) { return x; };
  uint32_t operator()(uint32_t x) { return x; };
  uint64_t operator()(uint64_t x) { return x; };
  bool operator()(bool x) { return x; };
};

// Imaginary part of a complex64_t, as float.
struct Imag {
  float operator()(complex64_t x) { return x.imag; };
};

// complex: log|z| + i * atan2(imag, real).
struct Log {
  template T operator()(T x) { return metal::precise::log(x); };
  complex64_t operator()(complex64_t x) {
    auto r = metal::precise::log(Abs{}(x).real);
    auto i = metal::precise::atan2(x.imag, x.real);
    return {r, i};
  };
};

// complex: natural log divided by ln(2).
struct Log2 {
  template T operator()(T x) { return metal::precise::log2(x); };
  complex64_t operator()(complex64_t x) {
    auto y = Log{}(x);
    return {y.real / M_LN2_F, y.imag / M_LN2_F};
  };
};

// complex: natural log divided by ln(10).
struct Log10 {
  template T operator()(T x) { return metal::precise::log10(x); };
  complex64_t operator()(complex64_t x) {
    auto y = Log{}(x);
    return {y.real / M_LN10_F, y.imag / M_LN10_F};
  };
};

// Calls an unqualified log1p. utils.h defines float, bfloat16_t and
// complex64_t overloads.
struct Log1p {
  template T operator()(T x) { return log1p(x); };
};

struct LogicalNot {
  template T operator()(T x) { return !x; };
};

struct Negative {
  template T operator()(T x) { return -x; };
};

// Real part of a complex64_t, as float.
struct Real {
  float operator()(complex64_t x) { return x.real; };
};

// Rounds with metal::rint; for complex input, each component is rounded.
struct Round {
  template T operator()(T x) { return metal::rint(x); };
  complex64_t operator()(complex64_t x) {
    return {metal::rint(x.real), metal::rint(x.imag)};
  };
};

// y = 1 / (1 + exp(|x|)) equals sigmoid(-|x|). The result is y for negative x
// and 1 - y otherwise, so exp is only ever evaluated at a non-negative
// argument.
struct Sigmoid {
  template T operator()(T x) {
    auto y = 1 / (1 + metal::exp(metal::abs(x)));
    return (x < 0) ? y : 1 - y;
  }
};

// Generic result: (x > 0) - (x < 0). uint32_t uses x != 0. Complex input
// returns z / |z|, and zero maps to zero.
struct Sign {
  template T operator()(T x) { return (x > T(0)) - (x < T(0)); };
  uint32_t operator()(uint32_t x) { return x != 0; };
  complex64_t operator()(complex64_t x) {
    if (x == complex64_t(0)) {
      return x;
    }
    return x /
        (complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag);
  };
};

// complex: sin(a + ib) = sin(a)cosh(b) + i cos(a)sinh(b).
struct Sin {
  template T operator()(T x) { return metal::precise::sin(x); };
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
        metal::precise::cos(x.real) * metal::precise::sinh(x.imag)};
  };
};

// complex: sinh(a + ib) = sinh(a)cos(b) + i cosh(a)sin(b).
struct Sinh {
  template T operator()(T x) { return metal::precise::sinh(x); };
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
        metal::precise::cosh(x.real) * metal::precise::sin(x.imag)};
  };
};

struct Square {
  template T operator()(T x) { return x * x; };
};

// complex: with r = |z|, returns sqrt((r + re) / 2) + i * s, where
// s = copysign(sqrt((r - re) / 2), im). Zero maps to zero.
struct Sqrt {
  template T operator()(T x) { return metal::precise::sqrt(x); };
  complex64_t operator()(complex64_t x) {
    if (x.real == 0.0 && x.imag == 0.0) {
      return {0.0, 0.0};
    }
    auto r = Abs{}(x).real;
    auto a = metal::precise::sqrt((r + x.real) / 2.0);
    auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
    auto b = metal::copysign(b_abs, x.imag);
    return {a, b};
  }
};

// complex: reciprocal of Sqrt.
struct Rsqrt {
  template T operator()(T x) { return metal::precise::rsqrt(x); };
  complex64_t operator()(complex64_t x) {
    return 1.0 / Sqrt{}(x);
  }
};

// complex: tan(a + ib) = (tan a + i tanh b) / (1 - i tan a tanh b), with the
// numerator and denominator multiplied by the denominator's conjugate.
struct Tan {
  template T operator()(T x) { return metal::precise::tan(x); };
  complex64_t operator()(complex64_t x) {
    float tan_a = metal::precise::tan(x.real);
    float tanh_b = metal::precise::tanh(x.imag);
    float t1 = tan_a * tanh_b;
    float denom = 1. + t1 * t1;
    return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
  };
};

// complex: tanh(a + ib) = (tanh a + i tan b) / (1 + i tanh a tan b), with the
// numerator and denominator multiplied by the denominator's conjugate.
struct Tanh {
  template T operator()(T x) { return metal::precise::tanh(x); };
  complex64_t operator()(complex64_t x) {
    float tanh_a = metal::precise::tanh(x.real);
    float tan_b = metal::precise::tan(x.imag);
    float t1 = tanh_a * tan_b;
    float denom = 1. + t1 * t1;
    return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
  };
};

// acos(z) = -i * log(z + i * sqrt(1 - z^2)); {y.imag, -y.real} is -i * y.
complex64_t ArcCos::operator()(complex64_t x) {
  auto i = complex64_t{0.0, 1.0};
  auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
  return {y.imag, -y.real};
};

// asin(z) = -i * log(i * z + sqrt(1 - z^2)).
complex64_t ArcSin::operator()(complex64_t x) {
  auto i = complex64_t{0.0, 1.0};
  auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
  return {y.imag, -y.real};
};

// atan(z) = (1 / 2i) * log((1 + iz) / (1 - iz)).
complex64_t ArcTan::operator()(complex64_t x) {
  auto i = complex64_t{0.0, 1.0};
  auto ix = i * x;
  return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
};

// Encodes the value as fp8_e4m3 and returns its raw bits.
struct ToFP8 {
  template uint8_t operator()(T f) {
    return fp8_e4m3(f).bits;
  }
};

// Reinterprets the byte as an fp8_e4m3 value and converts it to float.
struct FromFP8 {
  float operator()(uint8_t x) {
    return float(*(thread fp8_e4m3*)(&x));
  }
};
================================================
FILE: mlx/backend/metal/kernels/utils.h
================================================
// Copyright © 2023-2024 Apple Inc.

#pragma once
#include
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/bf16_math.h"
#include "mlx/backend/metal/kernels/complex.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/logging.h"

typedef half float16_t;

// Work per thread values for different types. The values here are expected to
// match get_work_per_thread in mlx/backend/metal/utils.h
// n = 8 / sizeof(U): each thread covers 8 bytes' worth of U elements; types
// wider than 8 bytes are rejected by the static_assert.
template struct WorkPerThread {
  static_assert(sizeof(U) <= 8, "Type too large");
  static constexpr int constant n = 8 / sizeof(U);
};

///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

// Generic limits taken from metal::numeric_limits; finite_* equal max/min.
template struct Limits {
  static const constant U max = metal::numeric_limits::max();
  static const constant U min = metal::numeric_limits::min();
  static const constant U finite_max = metal::numeric_limits::max();
  static const constant U finite_min = metal::numeric_limits::min();
};

// Integer specializations: all four limits come from numeric_limits.
#define instantiate_default_limit(type) \
  template <> \
  struct Limits { \
    static constexpr constant type max = metal::numeric_limits::max(); \
    static constexpr constant type min = metal::numeric_limits::min(); \
    static constexpr constant type finite_max = \
        metal::numeric_limits::max(); \
    static constexpr constant type finite_min = \
        metal::numeric_limits::min(); \
  };

instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);

// Floating-point specializations: max/min are +/-infinity, while
// finite_max/finite_min are +/- the largest finite value.
#define instantiate_float_limit(type) \
  template <> \
  struct Limits { \
    static constexpr constant type max = \
        metal::numeric_limits::infinity(); \
    static constexpr constant type min = \
        -metal::numeric_limits::infinity(); \
    static constexpr constant type finite_max = \
        metal::numeric_limits::max(); \
    static constexpr constant type finite_min = \
        -metal::numeric_limits::max(); \
  };

instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);

// The bool and complex64_t specializations define only max and min (no
// finite_max/finite_min).
template <> struct Limits {
  static constexpr constant bool max = true;
  static constexpr constant bool min = false;
};

template <> struct Limits {
  static constexpr constant complex64_t max = complex64_t(
      metal::numeric_limits::infinity(),
      metal::numeric_limits::infinity());
  static constexpr constant complex64_t min = complex64_t(
      -metal::numeric_limits::infinity(),
      -metal::numeric_limits::infinity());
};

///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////

#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims

// Maps a row-major linear element index to a memory offset using shape and
// strides, peeling off dimensions from the innermost outward. The loop stops
// once the remaining index reaches 0, because every further term would be 0.
template METAL_FUNC IdxT elem_to_loc(
    IdxT elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

// Non templated version to handle arbitrary dims
// elem.x indexes the innermost dimension and elem.y the second-innermost;
// elem.z is decomposed over the remaining dimensions. This reads
// strides[ndim - 2], so callers presumably guarantee ndim >= 2.
template METAL_FUNC IdxT elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT loc =
      elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
  for (int d = ndim - 3; d >= 0; --d) {
    loc += (elem.z % shape[d]) * IdxT(strides[d]);
    elem.z /= shape[d];
  }
  return loc;
}

///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims
// elem.x pairs with the last stride entry and the highest used component
// pairs with strides[0].

template METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
  return elem * IdxT(stride);
}

template METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
  return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
}

template METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
  return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
      elem.z * IdxT(strides[0]);
}

///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims

// Same decomposition as the uint3 elem_to_loc, computing offsets for two
// arrays that share `shape` in a single pass.
template METAL_FUNC vec elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    int ndim) {
  vec loc = {
      IdxT(
          elem.x * IdxT(a_strides[ndim - 1]) +
          IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
      IdxT(
          elem.x * IdxT(b_strides[ndim - 1]) +
          elem.y * IdxT(b_strides[ndim - 2]))};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * IdxT(a_strides[d]);
    loc.y += l * IdxT(b_strides[d]);
    elem.z /= shape[d];
  }
  return loc;
}

// Three-array variant of elem_to_loc_2_nd.
template METAL_FUNC vec elem_to_loc_3_nd(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    int ndim) {
  vec loc = {
      IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
          IdxT(elem.y * IdxT(a_strides[ndim - 2])),
      IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
          IdxT(elem.y * IdxT(b_strides[ndim - 2])),
      IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
          IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * IdxT(a_strides[d]);
    loc.y += l * IdxT(b_strides[d]);
    loc.z += l * IdxT(c_strides[d]);
    elem.z /= shape[d];
  }
  return loc;
}

///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////

// Tracks the memory offset of a row-major element index incrementally as the
// index is advanced. This level owns dimension `dim - 1`. `inner_looper`
// (constructed with dim - 1) tracks the outer dimensions and is advanced
// whenever this level's index wraps.
// NOTE(review): the template parameter list and inner_looper's template
// arguments are not visible here; presumably the recursion ends at the
// <1, OffsetT, ...> specializations below.
template struct LoopedElemToLoc {
  int dim;
  LoopedElemToLoc inner_looper;
  OffsetT offset{0};
  int index{0};

  LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

  // Advance by one element, carrying into the outer dimensions on wrap.
  void next(const constant int* shape, const constant int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index++;
    offset += OffsetT(strides[dim - 1]);
    if (index >= shape[dim - 1]) {
      index = 0;
      inner_looper.next(shape, strides);
      offset = inner_looper.offset;
    }
  }

  // Advance by n elements. If this dimension overflows, the outer looper is
  // advanced by the number of completed rows, this level resets to the start
  // of the row, and any remainder is re-applied recursively.
  void next(int n, const constant int* shape, const constant int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index += n;
    offset += n * OffsetT(strides[dim - 1]);
    if (index >= shape[dim - 1]) {
      int extra = index - shape[dim - 1];
      if (extra >= shape[dim - 1]) {
        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
        extra = extra % shape[dim - 1];
      } else {
        inner_looper.next(shape, strides);
      }
      index = 0;
      offset = inner_looper.offset;
      if (extra > 0) {
        next(extra, shape, strides);
      }
    }
  }

  OffsetT location() {
    return offset;
  }
};

// Single-level looper with general fallback: when dim > 1 the offset is
// recomputed with elem_to_loc from the running linear index; otherwise it
// steps by strides[0].
template struct LoopedElemToLoc<1, OffsetT, true> {
  int dim;
  OffsetT offset{0};
  uint index{0};

  LoopedElemToLoc(int dim) : dim(dim) {}

  void next(const constant int* shape, const constant int64_t* strides) {
    index++;
    if (dim > 1) {
      offset = elem_to_loc(index, shape, strides, dim);
    } else {
      offset += OffsetT(strides[0]);
    }
  }

  void next(int n, const constant int* shape, const constant int64_t* strides) {
    index += n;
    if (dim > 1) {
      offset = elem_to_loc(index, shape, strides, dim);
    } else {
      offset = index * OffsetT(strides[0]);
    }
  }

  OffsetT location() {
    return offset;
  }
};

// Single-level looper without the fallback: always steps by strides[0] and
// ignores shape.
template struct LoopedElemToLoc<1, OffsetT, false> {
  OffsetT offset{0};

  LoopedElemToLoc(int) {}

  void next(const constant int*, const constant int64_t* strides) {
    offset += OffsetT(strides[0]);
  }

  void next(int n, const constant int*, const constant int64_t* strides) {
    offset += n * OffsetT(strides[0]);
  }

  OffsetT location() {
    return offset;
  }
};

///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////

/** Compute ceil((float)N/(float)M) */
// Integer form; equals the ceiling for N >= 0 and M > 0.
template inline T ceildiv(T N, U M) {
  return (N + M - 1) / M;
}

// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
// Computes log(1 + x) as x * log(1 + x) / ((1 + x) - 1), which compensates
// for the rounding error in forming 1 + x. If 1 + x equals Limits::max
// (presumably the float specialization, i.e. +inf), that value is returned.
// If 1 + x rounds to exactly 1, x is returned.
inline float log1p(float x) {
  float xp1 = 1.0f + x;
  if (xp1 == Limits::max) {
    return Limits::max;
  }
  if (xp1 == 1.0f) {
    return x;
  }
  return x * (metal::log(xp1) / (xp1 - 1.0f));
}

// bfloat16_t variant of the above; the intermediate 1 + x is a float.
inline bfloat16_t log1p(bfloat16_t x) {
  float xp1 = 1.0f + static_cast(x);
  if (xp1 == Limits::max) {
    return Limits::max;
  }
  if (xp1 == 1.0f) {
    return x;
  }
  return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}

// Complex log1p. The imaginary part is atan2(y, x + 1).
// - For |z| < 0.5, the real part is 0.5 * log1p(r) with
//   r = x(2 + x) + y^2 = |1 + z|^2 - 1, which avoids forming |1 + z| near 1.
// - An r of exactly 0 returns x as the real part.
// - Otherwise, log|1 + z| is used directly.
inline complex64_t log1p(complex64_t in) {
  float x = in.real;
  float y = in.imag;
  float zabs = metal::precise::sqrt(x * x + y * y);
  float theta = metal::atan2(y, x + 1);
  if (zabs < 0.5f) {
    float r = x * (2 + x) + y * y;
    if (r == 0) { // handle underflow
      return {x, theta};
    }
    return {0.5f * log1p(r), theta};
  } else {
    auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);
    return {metal::log(z0), theta};
  }
}

///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////

// Overloads that forward to the metal:: SIMD shuffles:
// - uint64_t/int64_t are bit-cast with as_type (the intermediate type is not
//   visible here).
// - bool goes through static_cast.
// - complex64_t shuffles the real and imaginary parts separately.

inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  return as_type(
      metal::simd_shuffle_down(as_type(data), delta));
}

inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  return as_type(
      metal::simd_shuffle_down(as_type(data), delta));
}

inline bool simd_shuffle_down(bool data, uint16_t delta) {
  return simd_shuffle_down(static_cast(data), delta);
}

inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
  return complex64_t(
      simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
}

inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
  return as_type(metal::simd_shuffle_up(as_type(data), delta));
}

inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
  return as_type(metal::simd_shuffle_up(as_type(data), delta));
}

inline bool simd_shuffle_up(bool data, uint16_t delta) {
  return simd_shuffle_up(static_cast(data), delta);
}

inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) {
  return complex64_t(
      simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta));
}

inline uint64_t
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
  return as_type(metal::simd_shuffle_and_fill_up(
      as_type(data), as_type(filling), delta));
}

inline int64_t
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
  return as_type(metal::simd_shuffle_and_fill_up(
      as_type(data), as_type(filling), delta));
}

inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
  return simd_shuffle_and_fill_up(
      static_cast(data), static_cast(filling), delta);
}

inline complex64_t simd_shuffle_and_fill_up(
    complex64_t data,
    complex64_t filling,
    uint16_t delta) {
  return complex64_t(
      simd_shuffle_and_fill_up(data.real, filling.real, delta),
      simd_shuffle_and_fill_up(data.imag, filling.imag, delta));
}

inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
  return as_type(metal::simd_shuffle(as_type(data), lane));
}

inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
  return as_type(metal::simd_shuffle(as_type(data), lane));
}

inline bool simd_shuffle(bool data, uint16_t lane) {
  return simd_shuffle(static_cast(data), lane);
}

inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
  return complex64_t(
      simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
}

// std::conditional is not included with Metal
// The primary template yields U and the specialization yields T. The
// specialization's arguments are not visible here; presumably it is the
// `true` case.
template struct ConditionalType {
  using type = U;
};
template struct ConditionalType {
  using type = T;
};
================================================
FILE: mlx/backend/metal/kernels.h
================================================
// Copyright © 2024 Apple Inc.
#pragma once
#include
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"

// Declarations of helpers that return an MTL::ComputePipelineState* for a
// kernel on device `d`, presumably the one named `kernel_name`. The remaining
// parameters presumably select the requested variant (dtypes, arrays, tile
// sizes, flags). Implementations are not in this header.
// NOTE(review): the name in the bare #include and several template arguments
// (e.g. `const std::optional& mask_out`, the template parameter list of
// get_template_definition) appear to be missing from this text; confirm
// against the upstream header.

namespace mlx::core {

MTL::ComputePipelineState* get_arange_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out);

MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype in_type,
    Dtype out_type,
    const char* op);

MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype in_type,
    Dtype out_type,
    const char* op);

MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype in_type,
    Dtype out_type,
    const char* op);

MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype type,
    const char* op);

MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);

MTL::ComputePipelineState* get_dynamic_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);

MTL::ComputePipelineState* get_softmax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool precise,
    const array& out);

MTL::ComputePipelineState* get_logsumexp_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out);

MTL::ComputePipelineState* get_scan_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool reverse,
    bool inclusive,
    const std::string& reduce_type,
    const array& in,
    const array& out);

MTL::ComputePipelineState* get_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    int bn,
    int tn);

MTL::ComputePipelineState* get_mb_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& idx,
    int bn,
    int tn);

MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& func_name,
    const std::string& op_name,
    const Dtype& out_type);

// ndim, bm and bn default to -1.
MTL::ComputePipelineState* get_reduce_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& func_name,
    const std::string& op_name,
    const Dtype& in_type,
    const Dtype& out_type,
    const std::string& idx_t,
    int ndim = -1,
    int bm = -1,
    int bn = -1);

// The steel GEMM/conv/attention getters below take tile parameters
// (bm, bn, bk, wm, wn); several also take a hash_name and a
// metal::MTLFCList of function constants.
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm, int bn, int bk, int wm, int wn);

MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm, int bn, int bk, int wm, int wn,
    bool mn_aligned,
    bool k_aligned);

MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool axbpy);

MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    const std::optional& mask_out,
    const std::optional& mask_op,
    bool transpose_a,
    bool transpose_b,
    int bm, int bn, int bk, int wm, int wn,
    bool mn_aligned,
    bool k_aligned);

MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm, int bn, int bk, int wm, int wn,
    bool rhs);

MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm, int bn, int bk, int wm, int wn);

MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    int bm, int bn, int bk, int wm, int wn,
    int n_channel_specialization,
    bool small_filter);

MTL::ComputePipelineState* get_steel_conv_3d_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    int bm, int bn, int bk, int wm, int wn,
    bool small_filter);

MTL::ComputePipelineState* get_gemv_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    const std::optional& mask_out,
    const std::optional& mask_op,
    bool transpose_mat,
    int bm, int bn, int sm, int sn, int tm, int tn,
    bool contiguous);

MTL::ComputePipelineState* get_steel_conv_general_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    int bm, int bn, int bk, int wm, int wn);

MTL::ComputePipelineState* get_fft_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const std::string& template_def);

MTL::ComputePipelineState* get_quantized_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& template_def,
    const std::string& mode);

MTL::ComputePipelineState* get_gather_qmm_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& x,
    int group_size,
    int bits,
    const std::string& mode,
    int bm, int bn, int bk, int wm, int wn,
    bool transpose);

MTL::ComputePipelineState* get_steel_gemm_fused_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm, int bn, int bk, int wm, int wn);

MTL::ComputePipelineState* get_steel_gemm_gather_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm, int bn, int bk, int wm, int wn,
    bool rhs);

MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm, int bn, int bk, int wm, int wn);

MTL::ComputePipelineState* get_qmm_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& template_def,
    const std::string& mode);

MTL::ComputePipelineState* get_gather_qmm_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& x,
    int group_size,
    int bits,
    const std::string& mode,
    int bm, int bn, int bk, int wm, int wn,
    bool transpose);

MTL::ComputePipelineState* get_steel_attention_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& q,
    int bq, int bk, int bd, int wm, int wn,
    const array& m);

MTL::ComputePipelineState* get_steel_attention_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& q,
    int bq, int bk, int bd, int wm, int wn,
    const array& m);

// Create a GPU kernel template definition for JIT compilation
// Streams `func` into angle brackets followed by each of `args` (written with
// operator<< and separated by ", "). The returned string is an explicit
// instantiation of the form
//   template [[host_name("NAME")]] [[kernel]] decltype(T) T;
// surrounded by newlines, where NAME is `name` and T is the bracketed text.
template std::string get_template_definition(
    std::string_view name,
    std::string_view func,
    Args... args) {
  std::ostringstream s;
  s << func << "<";
  bool first = true;
  auto add_arg = [&s, &first](const auto& arg) {
    if (!first) {
      s << ", ";
    }
    first = false;
    s << arg;
  };
  (add_arg(args), ...);
  s << ">";
  return fmt::format(
      "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
      name,
      s.str());
}

} // namespace mlx::core
================================================
FILE: mlx/backend/metal/logsumexp.cpp
================================================
// Copyright © 2023-2024 Apple Inc.
#include #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" namespace mlx::core { constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096; void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[logsumexp] Does not support non-floating point types."); } auto& s = stream(); auto& d = metal::device(s.device); // Make sure that the last dimension is contiguous auto ensure_contiguous = [&s, &d](const array& x) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { return x; } else { array x_copy = contiguous_copy_gpu(x, s); d.add_temporary(x_copy, s.index); return x_copy; } }; auto in = ensure_contiguous(inputs[0]); if (in.flags().row_contiguous) { out.set_data(allocator::malloc(out.nbytes())); } else { auto n = in.shape(-1); auto flags = in.flags(); auto strides = in.strides(); for (auto& s : strides) { s /= n; } bool col_contig = strides[0] == 1; for (int i = 1; col_contig && i < strides.size(); ++i) { col_contig &= (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]); } flags.col_contiguous = col_contig; out.set_data( allocator::malloc(in.nbytes() / n), in.data_size() / n, std::move(strides), flags); } int axis_size = in.shape().back(); int n_rows = in.data_size() / axis_size; const int simd_size = 32; const int n_reads = 4; const int looped_limit = LOGSUMEXP_LOOPED_LIMIT; std::string kernel_name = (axis_size > looped_limit) ? 
"looped_" : "block_"; kernel_name += "logsumexp_"; kernel_name += type_to_name(out); auto kernel = get_logsumexp_kernel(d, kernel_name, out); auto& compute_encoder = d.get_command_encoder(s.index); { MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } else { size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup(); size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); compute_encoder.set_bytes(axis_size, 2); compute_encoder.dispatch_threads(grid_dims, group_dims); } } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/make_compiled_preamble.sh ================================================ #!/bin/bash # # This script generates a C++ function that provides the Metal source code # at runtime for use with kernel generation. # # The steps executed are as follows # - Take as input a metal header file in the mlx metal backend # - Use the metal compiler to expand the dependency headers # - Sort the headers in order of inclusion # - Expand the headers in order of inclusion # - Export the generated source code content as a C++ function # # Doing the expansion this way allows us to retain macros, comments, and # formatting in the expanded source. 
This adds user readibility, and also # enables use of the metal macros in the source code which can then be # handled by the metal runtime compiler # # Copyright © 2023-25 Apple Inc. OUTPUT_DIR=$1 CC=$2 SRC_DIR=$3 SRC_FILE=$4 CFLAGS=$5 SRC_NAME=$(basename -- "${SRC_FILE}") JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp # Prepare output mkdir -p "$OUTPUT_DIR" # Use the metal compiler to get a list of headers (with depth) CCC="xcrun -sdk macosx metal -x metal" HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null ) # Remove any included system frameworks (for MetalPerformancePrimitive headers) HDRS=$(echo "$HDRS" | grep -v "Xcode") # Use the header depth to sort the files in order of inclusion declare -a HDRS_LIST=($HDRS) declare -a HDRS_STACK=() declare -a HDRS_SORTED=() length=${#HDRS_LIST[@]} HDRS_LIST+=(".") for ((i=0; i<${length}; i+=2)); do header="${HDRS_LIST[$i+1]#$SRC_DIR/}" str_this="${HDRS_LIST[$i]}" str_next="${HDRS_LIST[$i + 2]}" depth_this=${#str_this} depth_next=${#str_next} # If we have a dependency then we stack it if [ $depth_next -gt $depth_this ]; then HDRS_STACK=($header ${HDRS_STACK[@]}) # If we are done with this level else # We add the header to out list HDRS_SORTED+=($header) # Pop the stacked up dependencies pop_len=$((depth_this - depth_next)) for popped_header in "${HDRS_STACK[@]:0:$pop_len}" do HDRS_SORTED+=($popped_header) done HDRS_STACK=(${HDRS_STACK[@]:$pop_len}) fi done # Make sure the given metal header is also expanded in the source content HDRS_SORTED+=("${INPUT_FILE#$SRC_DIR/}") # Expand the headers in order of inclusion CONTENT=$( echo "// Copyright © 2025 Apple Inc." 
echo "" echo "// Auto generated source for ${INPUT_FILE#$SRC_DIR/}" echo "" for header in "${HDRS_SORTED[@]}" do echo "///////////////////////////////////////////////////////////////////////////////" echo "// Contents from \"${header}\"" echo "///////////////////////////////////////////////////////////////////////////////" echo "" echo "#line 1 \"${header}\"" grep -h -v -G -e "#include \".*.h\"" -e "#pragma once" "${SRC_DIR}/${header}" echo "" done echo "///////////////////////////////////////////////////////////////////////////////" ) # Export the generated source code content as a C++ function cat << EOF > "$OUTPUT_FILE" namespace mlx::core::metal { const char* $SRC_NAME() { return R"preamble( $CONTENT )preamble"; } } // namespace mlx::core::metal EOF ================================================ FILE: mlx/backend/metal/matmul.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include #include #include #include #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/matmul.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/binary.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/steel/gemm/params.h" #include "mlx/backend/metal/matmul.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { namespace { std::tuple check_transpose( std::vector& copies, const Stream& s, const array& arr, bool is_vector) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (sty == 1 && (!is_vector || stx == arr.shape(-1))) { return std::make_tuple(false, stx, arr); } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { return std::make_tuple(true, sty, arr); } else { array arr_copy = contiguous_copy_gpu(arr, s); copies.push_back(arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy); } }; 
inline array ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { if (!x.flags().row_contiguous) { array x_copy = contiguous_copy_gpu(x, s); d.add_temporary(x_copy, s.index); return x_copy; } else { return x; } } inline std::tuple ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { if (x.flags().row_contiguous) { return std::make_tuple(false, x.strides()[x.ndim() - 2], x); } bool rc = true; for (int i = 0; i < x.ndim() - 3; i++) { rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i]; } if (rc) { auto stx = x.strides()[x.ndim() - 2]; auto sty = x.strides()[x.ndim() - 1]; auto K = x.shape(-2); auto N = x.shape(-1); if (sty == 1 && (N != 1 || stx == N)) { return std::make_tuple(false, stx, x); } if (stx == 1 && (N != 1 || sty == K)) { return std::make_tuple(true, sty, x); } } array x_copy = contiguous_copy_gpu(x, s); d.add_temporary(x_copy, s.index); return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy); } } // namespace /////////////////////////////////////////////////////////////////////////////// // Steel matmul fallback /////////////////////////////////////////////////////////////////////////////// #define GEMM_TPARAM_MACRO(devc) \ if (devc == 'g' || devc == 'p') { /* Small device */ \ if (out.dtype() == complex64) { \ bm = 64; \ bn = 32; \ bk = 8; \ wm = 4; \ wn = 1; \ } else if (!transpose_a && transpose_b) { /* nt */ \ bm = 64; \ bn = 32; \ bk = 32; \ wm = 2; \ wn = 2; \ } else if (out.dtype() != float32) { /* half and bfloat */ \ bm = 64; \ bn = 64; \ bk = 16; \ wm = 1; \ wn = 2; \ } \ } else if (devc == 'd') { /* Large device */ \ if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \ if (out.dtype() != float32) { /* half and bfloat */ \ if (2 * std::max(M, N) > K) { /* Reasonable K */ \ bm = 64; \ bn = 64; \ bk = 16; \ wm = 1; \ wn = 2; \ } else if (!transpose_a && transpose_b) { /* nt with large k */ \ bm = 64; \ bn = 32; \ bk = 32; \ wm = 2; \ wn = 2; \ } else { /* nn 
with large K */ \ bm = 32; \ bn = 64; \ bk = 16; \ wm = 1; \ wn = 2; \ } \ } /* float takes default */ \ } else { /* smaller matmul */ \ if (out.dtype() != float32) { /* half and bfloat */ \ if (!transpose_a && transpose_b) { /* nt */ \ bm = 64; \ bn = 32; \ bk = 32; \ wm = 2; \ wn = 2; \ } else { /* nn */ \ bm = 64; \ bn = 64; \ bk = 16; \ wm = 1; \ wn = 2; \ } \ } else { /* floats */ \ if (!transpose_a && transpose_b) { /* nt */ \ bm = 32; \ bn = 64; \ bk = 16; \ wm = 1; \ wn = 2; \ } else { /* nn */ \ bm = 64; \ bn = 32; \ bk = 32; \ wm = 2; \ wn = 2; \ } \ } \ } \ } else { /* Medium device */ \ bm = 64; \ bn = 64; \ bk = 16; \ wm = 2; \ wn = 2; \ } /////////////////////////////////////////////////////////////////////////////// // Regular steel matmul dispatch /////////////////////////////////////////////////////////////////////////////// template void steel_matmul_regular_axpby_nax( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, int64_t matrix_stride_out, int64_t C_batch_stride /* = 0*/, float alpha /* = 1.0f */, float beta /* = 0.0f */) { using namespace mlx::steel; // Determine dispatch kernel int bm = 128, bn = 128, bk = 512; int wm = 4, wn = 4; // Temp routing for larger devices char devc = d.get_architecture().back(); if (devc == 's' || devc == 'c' || devc == 'd') { bk = (K >= 8192 && K > (M + N)) ? 64 : 256; bm = 64; wm = 2; } // Prepare kernel name std::ostringstream kname; // clang-format off kname << "steel_gemm_fused_nax_" << (transpose_a ? 't' : 'n') << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; // clang-format on std::string base_name = kname.str(); const bool has_batch = (batch_shape.size() > 1); const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, {&use_out_source, MTL::DataType::DataTypeBool, 100}, {&do_axpby, MTL::DataType::DataTypeBool, 110}, {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, }; // clang-format off kname << "_has_batch_" << (has_batch ? 't' : 'n') << "_use_out_source_" << (use_out_source ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_fused_nax_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ base_name, /* const std::string& hash_name = */ hash_name, /* const metal::MTLFCList& func_consts = */ func_consts, /* const array& out = */ out, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* int bm = */ bm, /* int bn = */ bn, /* int bk = */ bk, /* int wm = */ wm, /* int wn = */ wn); compute_encoder.set_compute_pipeline_state(kernel); // Use problem size to determine threadblock swizzle int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; // TODO: Explore device-based tuning for swizzle int swizzle_log = tm <= 3 ? 
0 : 1; if (devc == 's' || devc == 'c' || devc == 'd') { swizzle_log = 2; } // Prepare steel matmul params GEMMParams params{/* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ ldb, /* const int ldd = */ ldd, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, /* const int64_t batch_stride_a = */ A_batch_stride, /* const int64_t batch_stride_b = */ B_batch_stride, /* const int64_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; // Prepare launch grid params int tile = 1 << swizzle_log; tm = (tm + tile - 1) / tile; tn = tn * tile; MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(params, 4); if (has_batch) { compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(batch_strides, 7); } if (use_out_source) { int ldc = c.strides()[c.ndim() - 2]; int fdc = c.strides()[c.ndim() - 1]; GEMMAddMMParams params{/* const int ldc = */ ldc, /* const int fdc = */ fdc, /* const int64_t batch_stride_c = */ C_batch_stride, /* const float alpha = */ alpha, /* const float beta = */ beta}; compute_encoder.set_input_array(c, 2); compute_encoder.set_bytes(params, 5); } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Record copies d.add_temporaries(std::move(copies), s.index); } template void steel_matmul_regular_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t 
B_batch_stride, int64_t matrix_stride_out, int64_t C_batch_stride /* = 0*/, float alpha /* = 1.0f */, float beta /* = 0.0f */) { if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) && (env::enable_tf32() || a.dtype() != float32)) { return steel_matmul_regular_axpby_nax( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* int ldd = */ ldd, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ batch_shape, /* Strides batch_strides = */ batch_strides, /* int64_t A_batch_stride = */ A_batch_stride, /* int64_t B_batch_stride = */ B_batch_stride, /* int64_t matrix_stride_out = */ matrix_stride_out, /* int64_t C_batch_stride = */ C_batch_stride, /* float alpha = */ alpha, /* float beta = */ beta); } using namespace mlx::steel; // Determine dispatch kernel int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; char devc = d.get_architecture().back(); GEMM_TPARAM_MACRO(devc) // Prepare kernel name std::ostringstream kname; // clang-format off kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; // clang-format on std::string base_name = kname.str(); const bool has_batch = (batch_shape.size() > 1); const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, {&use_out_source, MTL::DataType::DataTypeBool, 100}, {&do_axpby, MTL::DataType::DataTypeBool, 110}, {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, }; // clang-format off kname << "_has_batch_" << (has_batch ? 't' : 'n') << "_use_out_source_" << (use_out_source ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_fused_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ base_name, /* const std::string& hash_name = */ hash_name, /* const metal::MTLFCList& func_consts = */ func_consts, /* const array& out = */ out, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* int bm = */ bm, /* int bn = */ bn, /* int bk = */ bk, /* int wm = */ wm, /* int wn = */ wn); compute_encoder.set_compute_pipeline_state(kernel); // Use problem size to determine threadblock swizzle int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; // TODO: Explore device-based tuning for swizzle int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 
0 : 2); // Prepare steel matmul params GEMMParams params{/* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ ldb, /* const int ldd = */ ldd, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, /* const int64_t batch_stride_a = */ A_batch_stride, /* const int64_t batch_stride_b = */ B_batch_stride, /* const int64_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; // Prepare launch grid params int tile = 1 << swizzle_log; tm = (tm + tile - 1) / tile; tn = tn * tile; MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(params, 4); if (has_batch) { compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(batch_strides, 7); } if (use_out_source) { int ldc = c.strides()[c.ndim() - 2]; int fdc = c.strides()[c.ndim() - 1]; GEMMAddMMParams params{/* const int ldc = */ ldc, /* const int fdc = */ fdc, /* const int64_t batch_stride_c = */ C_batch_stride, /* const float alpha = */ alpha, /* const float beta = */ beta}; compute_encoder.set_input_array(c, 2); compute_encoder.set_bytes(params, 5); } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Record copies d.add_temporaries(std::move(copies), s.index); } /////////////////////////////////////////////////////////////////////////////// // Split k steel matmul /////////////////////////////////////////////////////////////////////////////// template void steel_gemm_splitk_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, 
std::vector& copies, float alpha = 1.0f, float beta = 0.0f) { using namespace mlx::steel; int _tm = (M + 32 - 1) / 32; int _tn = (N + 32 - 1) / 32; int _tk = K / 16; int bm = M < 40 ? 16 : 32; int bn = N < 40 ? 16 : 32; int bk = 16; int wm = 2, wn = 2; // As _tk grows use more partitions, as _tm * _tn grow use fewer partitions int split_k_partitions = std::min(std::max(2, next_power_of_2(_tk / (_tm * _tn))), 32); int split_k_partition_stride = M * N; int gemm_k_iterations = (K / bk) / split_k_partitions; int split_k_partition_size = gemm_k_iterations * bk; array C_split( {split_k_partitions, M, N}, issubdtype(out.dtype(), complexfloating) ? complex64 : float32, nullptr, {}); C_split.set_data(allocator::malloc(C_split.nbytes())); copies.push_back(C_split); bool mn_aligned = M % bm == 0 && N % bn == 0; bool k_aligned = K % bk == 0; std::ostringstream kname; // clang-format off kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" << "_K_" << (k_aligned ? 
"t" : "n") << "aligned"; // clang-format on // Encode and dispatch gemm kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_splitk_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ kname.str(), /* const array& in = */ a, /* const array& out = */ C_split, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* int bm = */ bm, /* int bn = */ bn, /* int bk = */ bk, /* int wm = */ wm, /* int wn = */ wn, /* bool mn_aligned = */ mn_aligned, /* bool k_aligned = */ k_aligned); compute_encoder.set_compute_pipeline_state(kernel); int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; GEMMSpiltKParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ ldb, /* const int ldc = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, /* const int split_k_partitions = */ split_k_partitions, /* const int split_k_partition_stride = */ split_k_partition_stride, /* const int split_k_partition_size = */ split_k_partition_size, /* const int swizzle_log = */ 0, // no swizzle /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(C_split, 2); compute_encoder.set_bytes(params, 3); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Do accum kernel { const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + type_to_name(C_split); if (do_axpby) { kernel_name = kernel_name + "_axbpy"; } auto kernel = get_steel_gemm_splitk_accum_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ kernel_name, /* const array& in = */ C_split, /* const array& out = */ out, /* bool axbpy = */ do_axpby); 
compute_encoder.set_compute_pipeline_state(kernel); // Set the arguments for the kernel compute_encoder.set_input_array(C_split, 0); compute_encoder.set_output_array(out, 1); compute_encoder.set_bytes(split_k_partitions, 2); compute_encoder.set_bytes(split_k_partition_stride, 3); compute_encoder.set_bytes(N, 4); if (do_axpby) { int ldc = c.strides()[c.ndim() - 2]; int fdc = c.strides()[c.ndim() - 1]; compute_encoder.set_input_array(c, 5); compute_encoder.set_bytes(ldc, 6); compute_encoder.set_bytes(fdc, 7); compute_encoder.set_bytes(alpha, 8); compute_encoder.set_bytes(beta, 9); } // Launch enough thread groups for each output MTL::Size grid_dims = MTL::Size(N, M, 1); auto group_dims = get_block_dims(N, M, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); } /////////////////////////////////////////////////////////////////////////////// // NAX Split k steel matmul /////////////////////////////////////////////////////////////////////////////// template void steel_gemm_splitk_axpby_nax( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector& copies, float alpha = 1.0f, float beta = 0.0f) { using namespace mlx::steel; constexpr int bm = 128, bn = 128, bk = 512; constexpr int wm = 4, wn = 4; // Determine how many partitions to split K into constexpr int split_k_partition_size = 3072; int split_k_partitions = (K + split_k_partition_size - 1) / split_k_partition_size; const int bk_iters_per_partition = split_k_partition_size / bk; const int split_k_partition_stride = M * N; array C_split({split_k_partitions, M, N}, float32, nullptr, {}); C_split.set_data(allocator::malloc(C_split.nbytes())); copies.push_back(C_split); const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; // Per-tile align_K is checked at runtime; 
only the last tile can be unaligned metal::MTLFCList func_consts = { {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}}; std::ostringstream kname; // clang-format off kname << "steel_gemm_splitk_nax_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; // clang-format on std::string base_name = kname.str(); // clang-format off kname << "_align_M_" << (align_M ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_splitk_nax_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ base_name, /* const std::string& hash_name = */ hash_name, /* const metal::MTLFCList& func_consts = */ func_consts, /* const array& out = */ C_split, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* int bm = */ bm, /* int bn = */ bn, /* int bk = */ bk, /* int wm = */ wm, /* int wn = */ wn); compute_encoder.set_compute_pipeline_state(kernel); int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; int swizzle_log = tm <= 3 ? 
0 : 1; // Compute swizzled tile counts int tile = 1 << swizzle_log; int tm_swizzled = (tm + tile - 1) / tile; int tn_swizzled = tn * tile; GEMMSpiltKParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ ldb, /* const int ldc = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, /* const int split_k_partitions = */ split_k_partitions, /* const int split_k_partition_stride = */ split_k_partition_stride, /* const int split_k_partition_size = */ split_k_partition_size, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ bk_iters_per_partition}; MTL::Size group_dims = MTL::Size(32, wn, wm); // Use 1D grid with K-partition-major layout: [Partition0: M×N // tiles][Partition1: M×N tiles]... Grid size is 1D to prevent driver/HW from // using its own heuristic to exploit 2D locality by launching threadgroups in // a non-linear order MTL::Size grid_dims = MTL::Size(tn_swizzled * tm_swizzled * split_k_partitions, 1, 1); compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(C_split, 2); compute_encoder.set_bytes(params, 3); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Do accum kernel { const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + type_to_name(C_split); if (do_axpby) { kernel_name = kernel_name + "_axbpy"; } auto kernel = get_steel_gemm_splitk_accum_kernel( /* metal::Device& d = */ d, /* const std::string& kernel_name = */ kernel_name, /* const array& in = */ C_split, /* const array& out = */ out, /* bool axbpy = */ do_axpby); compute_encoder.set_compute_pipeline_state(kernel); // Set the arguments for the kernel compute_encoder.set_input_array(C_split, 0); compute_encoder.set_output_array(out, 1); compute_encoder.set_bytes(split_k_partitions, 2); 
compute_encoder.set_bytes(split_k_partition_stride, 3); compute_encoder.set_bytes(N, 4); if (do_axpby) { int ldc = c.strides()[c.ndim() - 2]; int fdc = c.strides()[c.ndim() - 1]; compute_encoder.set_input_array(c, 5); compute_encoder.set_bytes(ldc, 6); compute_encoder.set_bytes(fdc, 7); compute_encoder.set_bytes(alpha, 8); compute_encoder.set_bytes(beta, 9); } // Launch enough thread groups for each output MTL::Size grid_dims = MTL::Size(N, M, 1); auto group_dims = get_block_dims(N, M, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); } /////////////////////////////////////////////////////////////////////////////// // Split matmul routing /////////////////////////////////////////////////////////////////////////////// template void steel_matmul_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape /* = {} */, Strides A_batch_stride /* = {} */, Strides B_batch_stride /* = {} */, Strides C_batch_stride /* = {} */, float alpha /* = 1.0f */, float beta /* = 0.0f */) { if (batch_shape.empty()) { ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions if constexpr (CHECK_AB) { auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] = collapse_batches(a, b, c); batch_shape = batch_shape_; A_batch_stride = A_bstride_; B_batch_stride = B_bstride_; C_batch_stride = C_bstride_; // Collapse batches into M if needed if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && C_batch_stride.back() == M * c.strides()[c.ndim() - 2] && B_batch_stride.back() == 0) { M *= batch_shape.back(); batch_size_out = 1; A_batch_stride = {0}; B_batch_stride = {0}; C_batch_stride = {0}; batch_shape = {1}; } } else { auto 
[batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); batch_shape = batch_shape_; A_batch_stride = A_bstride_; B_batch_stride = B_bstride_; // Collapse batches into M if needed if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && B_batch_stride.back() == 0) { M *= batch_shape.back(); batch_size_out = 1; A_batch_stride = {0}; B_batch_stride = {0}; batch_shape = {1}; } } } ///////////////////////////////////////////////////////////////////////////// // Split K specialization int _tm = (M + 16 - 1) / 16; int _tn = (N + 16 - 1) / 16; int _tk = K / 16; // Case 1: Small M×N with large K, use SIMD split-K char devc = d.get_architecture().back(); // Max and Ultra dispatch larger sizes to splitk int min_tmn_threshold = (devc == 's' || devc == 'd') ? 2048 : 1024; if (batch_size_out == 1 && (_tm * _tn) <= min_tmn_threshold && _tk >= 8 && K >= std::max(M, N)) { return steel_gemm_splitk_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* float alpha = */ alpha, /* float beta = */ beta); } // Case 2: Large K with sufficient M, N, and NAX is available, use NAX split-K // TODO: Add device-specific tuning for more NAX GPUs in the future constexpr int min_mn_threshold = 2048 * 2048; constexpr int min_k_threshold = 10240; if (batch_size_out == 1 && metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) && (env::enable_tf32() || a.dtype() != float32) && int64_t(M) * N >= min_mn_threshold && K >= min_k_threshold && K >= (3 * std::max(M, N))) { return steel_gemm_splitk_axpby_nax( /* const Stream& s = */ s, /* metal::Device& 
d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* float alpha = */ alpha, /* float beta = */ beta); } ///////////////////////////////////////////////////////////////////////////// // Regular kernel dispatch auto batch_strides = A_batch_stride; batch_strides.insert( batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); if (CHECK_AB && !C_batch_stride.empty()) { batch_strides.insert( batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); } int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back(); int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back(); int64_t C_batch_stride_ = C_batch_stride.empty() ? 0 : C_batch_stride.back(); return steel_matmul_regular_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* int ldd = */ N, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ std::move(batch_shape), /* Strides batch_strides = */ std::move(batch_strides), /* int64_t A_batch_stride = */ A_batch_stride_, /* int64_t B_batch_stride = */ B_batch_stride_, /* int64_t matrix_stride_out = */ int64_t(M) * N, /* int64_t C_batch_stride = */ C_batch_stride_, /* float alpha = */ alpha, /* float beta = */ beta); } /////////////////////////////////////////////////////////////////////////////// // GEMV dispatch 
/////////////////////////////////////////////////////////////////////////////// template void gemv_axbpy( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, Strides B_batch_stride = {}, Strides C_batch_stride = {}, float alpha = 1.0f, float beta = 0.0f) { // Collect problem info bool is_b_matrix = N != 1; auto& mat = is_b_matrix ? b : a; auto& vec = is_b_matrix ? a : b; bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; int in_vector_len = K; int out_vector_len = is_b_matrix ? N : M; int mat_ld = is_b_matrix ? ldb : lda; auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; // Determine if inputs have simple batching / broadcasting bool contiguous_kernel = (batch_shape.size() == 1); int batch_ndim = batch_shape.size(); // Determine dispatch kernel int tm = 4, tn = 4; int sm = 1, sn = 32; int bm = 1, bn = 1; int n_out_per_tgp; std::ostringstream kname; if (transpose_mat) { if (in_vector_len >= 8192 && out_vector_len >= 2048) { sm = 4; sn = 8; } else { sm = 8; sn = 4; } if (out_vector_len >= 2048) { bn = 16; } else if (out_vector_len >= 512) { bn = 4; } else { bn = 2; } // Specialized kernel for very small outputs tn = out_vector_len < tn ? 1 : tn; n_out_per_tgp = bn * sn * tn; kname << "gemv_t_" << type_to_name(out); } else { bm = out_vector_len >= 4096 ? 8 : 4; sn = 32; if (K <= 64) { bm = 1; sm = 8; sn = 4; } else if (K >= 16 * out_vector_len) { bm = 1; bn = 8; } // Specialized kernel for very small outputs tm = out_vector_len < tm ? 
1 : tm; n_out_per_tgp = bm * sm * tm; kname << "gemv_" << type_to_name(out); } const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); // clang-format off kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" << tm << "_tn" << tn << "_nc" << !contiguous_kernel << "_axpby" << do_axpby; // clang-format on // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); compute_encoder.set_input_array(mat, 0); compute_encoder.set_input_array(vec, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(in_vector_len, 4); compute_encoder.set_bytes(out_vector_len, 5); compute_encoder.set_bytes(mat_ld, 6); compute_encoder.set_bytes(batch_ndim, 9); compute_encoder.set_vector_bytes(batch_shape, 10); compute_encoder.set_vector_bytes(batch_strides_vec, 11); compute_encoder.set_vector_bytes(batch_strides_mat, 12); if (do_axpby) { compute_encoder.set_input_array(c, 2); compute_encoder.set_bytes(alpha, 7); compute_encoder.set_bytes(beta, 8); compute_encoder.set_vector_bytes(C_batch_stride, 13); int bias_stride = c.strides()[c.ndim() - 1]; compute_encoder.set_bytes(bias_stride, 14); } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } inline void gemv( const Stream& s, metal::Device& d, const array& a, const array& b, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, Strides B_batch_stride = {}) { return gemv_axbpy( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ b, /* array& out = */ 
out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ batch_shape, /* Strides A_batch_stride = */ A_batch_stride, /* Strides B_batch_stride = */ B_batch_stride); } /////////////////////////////////////////////////////////////////////////////// // Matmul implementation /////////////////////////////////////////////////////////////////////////////// void Matmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (!issubdtype(out.dtype(), inexact)) { throw std::runtime_error("[matmul] dtype must be inexact."); } auto& s = stream(); auto& d = metal::device(s.device); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; // Return 0s if either input is empty if (a_pre.size() == 0 || b_pre.size() == 0) { array zero = array(0, a_pre.dtype()); fill_gpu(zero, out, s); d.add_temporary(std::move(zero), s.index); return; } out.set_data(allocator::malloc(out.nbytes())); ///////////////////////////////////////////////////////////////////////////// // Init checks and prep int M = a_pre.shape(-2); int N = b_pre.shape(-1); int K = a_pre.shape(-1); // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; auto [a_transposed, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b); auto batch_size_out = out.size() / (size_t(M) * size_t(N)); // Collapse batches into M if needed if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 && a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && 
B_batch_stride.back() == 0) { M *= batch_shape.back(); batch_size_out = 1; A_batch_stride = {0}; B_batch_stride = {0}; batch_shape = {1}; } ///////////////////////////////////////////////////////////////////////////// // Gemv specialization // Route to gemv if needed if (std::min(M, N) == 1) { return gemv( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ a_cols, /* int ldb = */ b_cols, /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, /* std::vector& copies = */ copies, /* Shape batch_shape = */ std::move(batch_shape), /* Strides A_batch_stride = */ std::move(A_batch_stride), /* Strides B_batch_stride = */ std::move(B_batch_stride)); } ///////////////////////////////////////////////////////////////////////////// // Gemm specialization return steel_matmul( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ a_cols, /* int ldb = */ b_cols, /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, /* std::vector& copies = */ copies, /* Shape batch_shape = */ std::move(batch_shape), /* Strides A_batch_stride = */ std::move(A_batch_stride), /* Strides B_batch_stride = */ std::move(B_batch_stride)); } /////////////////////////////////////////////////////////////////////////////// // AddMM implementation /////////////////////////////////////////////////////////////////////////////// void AddMM::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } // Return 0s if either input is empty if (out.size() 
== 0) { out.set_data(allocator::malloc(out.nbytes())); return; } auto& s = stream(); auto& d = metal::device(s.device); // Handle empty matrix case (K=0) if (inputs[0].shape(-1) == 0) { auto& c = inputs[2]; if (beta_ == 1.0f) { copy_gpu( c, out, c.flags().row_contiguous ? CopyType::Vector : CopyType::General, s); } else { array beta_scalar = array(beta_, c.dtype()); binary_op_gpu({c, beta_scalar}, out, "Multiply", s); d.add_temporary(std::move(beta_scalar), s.index); } return; } out.set_data(allocator::malloc(out.nbytes())); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; auto& c_pre = inputs[2]; ///////////////////////////////////////////////////////////////////////////// // Init checks and prep int M = a_pre.shape(-2); int N = b_pre.shape(-1); int K = a_pre.shape(-1); // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); array c = c_pre; int lda = a_cols; int ldb = b_cols; ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] = collapse_batches(a, b, c); int64_t matrix_stride_out = M * static_cast(N); auto batch_size_out = out.size() / (matrix_stride_out); // Collapse batches into M if needed if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && C_batch_stride.back() == M * c.strides()[c.ndim() - 2] && B_batch_stride.back() == 0) { M *= batch_shape.back(); batch_size_out = 1; A_batch_stride = {0}; B_batch_stride = {0}; C_batch_stride = {0}; batch_shape = {1}; } ///////////////////////////////////////////////////////////////////////////// // Gemv specialization // Route to gemv if needed if (std::min(M, N) == 1) { return gemv_axbpy( /* const Stream& 
s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ batch_shape, /* Strides A_batch_stride = */ A_batch_stride, /* Strides B_batch_stride = */ B_batch_stride, /* Strides C_batch_stride = */ C_batch_stride, /* float alpha = */ alpha_, /* float beta = */ beta_); } ///////////////////////////////////////////////////////////////////////////// // Regular addmm dispatch return steel_matmul_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ c, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ batch_shape, /* Strides A_batch_stride = */ A_batch_stride, /* Strides B_batch_stride = */ B_batch_stride, /* Strides B_batch_stride = */ C_batch_stride, /* float alpha = */ alpha_, /* float beta = */ beta_); } /////////////////////////////////////////////////////////////////////////////// // BlockMaskedMM implementation /////////////////////////////////////////////////////////////////////////////// void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { using namespace mlx::steel; // assert(inputs.size() == 2); if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } auto& s = stream(); auto& d = metal::device(s.device); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; // Return 0s if either input is empty if 
(a_pre.size() == 0 || b_pre.size() == 0) { array zero = array(0, a_pre.dtype()); fill_gpu(zero, out, s); d.add_temporary(std::move(zero), s.index); return; } out.set_data(allocator::malloc(out.nbytes())); ///////////////////////////////////////////////////////////////////////////// // Init checks and prep int M = a_pre.shape(-2); int N = b_pre.shape(-1); int K = a_pre.shape(-1); // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); int lda = a_cols; int ldb = b_cols; ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions bool has_op_mask = inputs.size() > 3; bool has_out_mask = inputs.size() == 3 || inputs.size() == 5; // Prepare kernel name std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask"; std::string op_mask_nm = has_op_mask ? 
type_to_name(inputs.back()) : "nomask"; Shape batch_shape{1}; Strides A_batch_stride{0}; Strides B_batch_stride{0}; Strides outmask_bstride{0}; Strides Amask_bstride{0}; Strides Bmask_bstride{0}; int64_t A_batch_str = 0; int64_t B_batch_str = 0; Strides batch_strides; if (out.ndim() > 2) { Shape bshape{out.shape().begin(), out.shape().end() - 2}; std::vector bstrides; for (auto& arr : inputs) { bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2); } // auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides); batch_shape = bshape; A_batch_str = bstrides[0].back(); B_batch_str = bstrides[1].back(); for (auto& bstr : bstrides) { batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end()); } A_batch_stride = bstrides[0]; B_batch_stride = bstrides[1]; if (has_out_mask) { outmask_bstride = bstrides[2]; } if (has_op_mask) { Amask_bstride = bstrides[has_out_mask + 2]; Bmask_bstride = bstrides[has_out_mask + 3]; } } else { batch_strides = Strides(inputs.size(), 0); } int64_t matrix_stride_out = static_cast(M) * N; size_t batch_size_out = out.size() / (matrix_stride_out); ///////////////////////////////////////////////////////////////////////////// // Gemv specialization // Route to gemv if needed if (std::min(M, N) == 1) { // Collect problem info bool is_b_matrix = N != 1; auto& mat = is_b_matrix ? b : a; auto& vec = is_b_matrix ? a : b; bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; int in_vector_len = K; int out_vector_len = is_b_matrix ? N : M; int mat_ld = is_b_matrix ? b_cols : a_cols; auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; auto mask_bstrides_mat = is_b_matrix ? Bmask_bstride : Amask_bstride; auto mask_bstrides_vec = is_b_matrix ? Amask_bstride : Bmask_bstride; auto mat_mask_idx = int(has_out_mask) + (is_b_matrix ? 3 : 2); auto vec_mask_idx = int(has_out_mask) + (is_b_matrix ? 
2 : 3); // Determine if inputs have simple batching / broadcasting bool contiguous_kernel = (batch_shape.size() == 1); int batch_ndim = batch_shape.size(); // Determine dispatch kernel int tm = 4, tn = 4; int sm = 1, sn = 32; int bm = 1, bn = 1; int n_out_per_tgp; std::ostringstream kname; if (transpose_mat) { sm = 8; sn = 4; bm = 1; bn = (block_size_ == 64 && out_vector_len >= 2048) ? 4 : 2; tm = block_size_ == 32 ? 4 : 8; tn = 4; // Specialized kernel for very small outputs tn = out_vector_len < tn ? 1 : tn; n_out_per_tgp = bn * sn * tn; kname << "gemv_t"; } else { if (block_size_ == 32) { sm = 4; sn = 8; bm = 2; } else { sm = 2; sn = 16; bm = out_vector_len >= 512 ? 4 : 2; } // Specialized kernel for very small outputs tm = out_vector_len < tm ? 1 : tm; n_out_per_tgp = bm * sm * tm; kname << "gemv"; } kname << "_outmask_" << out_mask_nm; kname << "_opmask_" << op_mask_nm; kname << "_" << type_to_name(out); kname << "_bm" << bm << "_bn" << bn; kname << "_sm" << sm << "_sn" << sn; kname << "_tm" << tm << "_tn" << tn; kname << "_nc" << !contiguous_kernel; // Encode and dispatch kernel auto kernel = get_gemv_masked_kernel( d, kname.str(), out, has_out_mask ? std::optional{inputs[2]} : std::nullopt, has_op_mask ? std::optional{inputs.back()} : std::nullopt, transpose_mat, bm, bn, sm, sn, tm, tn, contiguous_kernel); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); // Get mask params std::vector mask_strides; Strides mask_batch_strides; if (has_out_mask) { auto& out_mask = inputs[2]; if (transpose_mat) { mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -1 : -2)); mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -2 : -1)); } else { mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? 
-1 : -2)); mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -2 : -1)); } mask_batch_strides.insert( mask_batch_strides.end(), outmask_bstride.begin(), outmask_bstride.end()); compute_encoder.set_input_array(out_mask, 20); } if (has_op_mask) { auto& mat_mask = inputs[mat_mask_idx]; if (transpose_mat) { mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -2 : -1)); mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -1 : -2)); } else { mask_strides.push_back(mat_mask.strides(is_b_matrix ? -2 : -1)); mask_strides.push_back(mat_mask.strides(is_b_matrix ? -1 : -2)); } mask_batch_strides.insert( mask_batch_strides.end(), mask_bstrides_mat.begin(), mask_bstrides_mat.end()); compute_encoder.set_input_array(mat_mask, 21); auto& vec_mask = inputs[vec_mask_idx]; if (transpose_mat) { mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -1 : -2)); mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -2 : -1)); } else { mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -1 : -2)); mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? 
-2 : -1)); } mask_batch_strides.insert( mask_batch_strides.end(), mask_bstrides_vec.begin(), mask_bstrides_vec.end()); compute_encoder.set_input_array(vec_mask, 22); } // Get gemv params compute_encoder.set_input_array(mat, 0); compute_encoder.set_input_array(vec, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(in_vector_len, 4); compute_encoder.set_bytes(out_vector_len, 5); compute_encoder.set_bytes(mat_ld, 6); compute_encoder.set_bytes(batch_ndim, 9); compute_encoder.set_vector_bytes(batch_shape, 10); compute_encoder.set_vector_bytes(batch_strides_vec, 11); compute_encoder.set_vector_bytes(batch_strides_mat, 12); compute_encoder.set_vector_bytes(mask_strides, 23); compute_encoder.set_vector_bytes(mask_batch_strides, 24); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); return; } ///////////////////////////////////////////////////////////////////////////// // Regular kernel dispatch // Determine dispatch kernel int bm = block_size_, bn = block_size_, bk = 16; int wm = 2, wn = 2; bool mn_aligned = M % bm == 0 && N % bn == 0; bool k_aligned = K % bk == 0; std::ostringstream kname; kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_" << op_mask_nm << "_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_masked_kernel( d, kname.str(), out, has_out_mask ? std::optional{inputs[2]} : std::nullopt, has_op_mask ? 
std::optional{inputs.back()} : std::nullopt, transpose_a, transpose_b, bm, bn, bk, wm, wn, mn_aligned, k_aligned); compute_encoder.set_compute_pipeline_state(kernel); // Use problem size to determine threadblock swizzle int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; // TODO: Explore device-based tuning for swizzle int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); // Prepare steel matmul params GEMMParams params{/* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ ldb, /* const int ldd = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, /* const int64_t batch_stride_a = */ A_batch_str, /* const int64_t batch_stride_b = */ B_batch_str, /* const int64_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; // Prepare launch grid params int tile = 1 << swizzle_log; tm = (tm + tile - 1) / tile; tn = tn * tile; MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); std::vector mask_strides; if (has_out_mask) { auto& out_mask = inputs[2]; mask_strides.push_back(*(out_mask.strides().end() - 1)); mask_strides.push_back(*(out_mask.strides().end() - 2)); compute_encoder.set_input_array(out_mask, 10); } if (has_op_mask) { auto& lhs_mask = inputs[2 + has_out_mask]; mask_strides.push_back(*(lhs_mask.strides().end() - 1)); mask_strides.push_back(*(lhs_mask.strides().end() - 2)); compute_encoder.set_input_array(lhs_mask, 11); auto& rhs_mask = inputs[3 + has_out_mask]; mask_strides.push_back(*(rhs_mask.strides().end() - 1)); mask_strides.push_back(*(rhs_mask.strides().end() - 2)); compute_encoder.set_input_array(rhs_mask, 12); } // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(params, 4); 
compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(batch_strides, 7); compute_encoder.set_vector_bytes(mask_strides, 13); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } /////////////////////////////////////////////////////////////////////////////// // GatherMM implementation /////////////////////////////////////////////////////////////////////////////// void gather_mm_rhs( const array& a_, const array& b_, const array& indices_, array& out, metal::Device& d, const Stream& s) { array indices = ensure_row_contiguous(indices_, d, s); auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s); // Broadcast a with indices. If we are here that means lhs_indices were not // provided so the lhs_indices are implied to be the shape of a broadcasted // with rhs_indices. We need only broadcast a and copy it as if applying the // lhs_indices. auto broadcast_with_indices = [&d, &s, &indices](const array& x) { if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { return ensure_row_contiguous(x, d, s); } auto x_shape = indices.shape(); x_shape.push_back(x.shape(-2)); x_shape.push_back(x.shape(-1)); array new_x(std::move(x_shape), x.dtype(), nullptr, {}); broadcast(x, new_x); return ensure_row_contiguous(new_x, d, s); }; array a = broadcast_with_indices(a_); // Extract the matmul shapes int K = a.shape(-1); int M = a.size() / K; int N = b.shape(-1); int lda = a.strides()[a.ndim() - 2]; // should be K // Define the dispatch blocks int bm = 16, bn = 64, bk = 16; int wm = 1, wn = 2; const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; // Define the kernel name std::string base_name; base_name.reserve(64); concatenate( base_name, "steel_gather_mm_rhs_n", transpose_b ? 
't' : 'n', '_', type_to_name(a), '_', type_to_name(out), "_bm", bm, "_bn", bn, "_bk", bk, "_wm", wm, "_wn", wn); metal::MTLFCList func_consts = { {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, }; // And the kernel hash that includes the function constants std::string hash_name; hash_name.reserve(128); concatenate( hash_name, base_name, "_align_M_", align_M ? 't' : 'n', "_align_N_", align_N ? 't' : 'n', "_align_K_", align_K ? 't' : 'n'); // Get and set the kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_gather_kernel( d, base_name, hash_name, func_consts, out, false, transpose_b, bm, bn, bk, wm, wn, true); compute_encoder.set_compute_pipeline_state(kernel); // Prepare the matmul params auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size(); steel::GEMMParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ static_cast(ldb), /* const int ldd = */ N, /* const int tiles_n = */ (N + bn - 1) / bn, /* const int tiles_m = */ (M + bm - 1) / bm, /* const int64_t batch_stride_a = */ 0, /* const int64_t batch_stride_b = */ static_cast(batch_stride_b), /* const int64_t batch_stride_d = */ 0, /* const int swizzle_log = */ 0, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ 0}; // Prepare the grid MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(indices, 2); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(params, 4); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void gather_mm_rhs_nax( const array& a_, const array& b_, const array& indices_, array& out, metal::Device& d, const Stream& s) { 
array indices = ensure_row_contiguous(indices_, d, s); auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s); // Broadcast a with indices. If we are here that means lhs_indices were not // provided so the lhs_indices are implied to be the shape of a broadcasted // with rhs_indices. We need only broadcast a and copy it as if applying the // lhs_indices. auto broadcast_with_indices = [&d, &s, &indices](const array& x) { if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { return ensure_row_contiguous(x, d, s); } auto x_shape = indices.shape(); x_shape.push_back(x.shape(-2)); x_shape.push_back(x.shape(-1)); array new_x(std::move(x_shape), x.dtype(), nullptr, {}); broadcast(x, new_x); return ensure_row_contiguous(new_x, d, s); }; array a = broadcast_with_indices(a_); // Extract the matmul shapes int K = a.shape(-1); int M = a.size() / K; int N = b.shape(-1); int lda = a.strides()[a.ndim() - 2]; // should be K int E = b.shape(0); // Define the dispatch blocks int bm, bn = 128, bk = 128, wm, wn = 4; if (M / E > 48) { bm = 64; wm = 2; } else if (M / E > 24) { bm = 32l; wm = 1; } else { bm = 16; wm = 1; } const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; // Define the kernel name std::string base_name; base_name.reserve(64); concatenate( base_name, "steel_gather_mm_rhs_nax_n", transpose_b ? 't' : 'n', '_', type_to_name(a), '_', type_to_name(out), "_bm", bm, "_bn", bn, "_bk", bk, "_wm", wm, "_wn", wn); metal::MTLFCList func_consts = { {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, }; // And the kernel hash that includes the function constants std::string hash_name; hash_name.reserve(128); concatenate( hash_name, base_name, "_align_M_", align_M ? 't' : 'n', "_align_N_", align_N ? 't' : 'n', "_align_K_", align_K ? 
't' : 'n'); // Get and set the kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_gather_nax_kernel( d, base_name, hash_name, func_consts, out, false, transpose_b, bm, bn, bk, wm, wn, true); compute_encoder.set_compute_pipeline_state(kernel); // Prepare the matmul params auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size(); steel::GEMMParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ lda, /* const int ldb = */ static_cast(ldb), /* const int ldd = */ N, /* const int tiles_n = */ (N + bn - 1) / bn, /* const int tiles_m = */ (M + bm - 1) / bm, /* const int64_t batch_stride_a = */ 0, /* const int64_t batch_stride_b = */ static_cast(batch_stride_b), /* const int64_t batch_stride_d = */ 0, /* const int swizzle_log = */ 0, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ 0}; // Prepare the grid MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(indices, 2); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(params, 4); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void gather_mv( const array& mat_, const array& vec_, const array& mat_indices_, const array& vec_indices_, array& out, int N, int K, bool is_mv, metal::Device& d, const Stream& s) { // Copy if needed std::vector copies; auto [transpose_mat, mat_cols, mat] = check_transpose(copies, s, mat_, N == 1); auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true); d.add_temporaries(std::move(copies), s.index); // If we are doing vector matrix instead of matrix vector we need to flip the // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated // as a one dimensional array. 
transpose_mat = (!is_mv) ^ transpose_mat; // Define some shapes int in_vector_len = K; int out_vector_len = N; int mat_ld = mat_cols; int batch_size_out = out.size() / N; int batch_ndim = out.ndim() - 2; int batch_ndim_mat = mat.ndim() - 2; int batch_ndim_vec = vec.ndim() - 2; Strides index_strides = vec_indices_.strides(); index_strides.insert( index_strides.end(), mat_indices_.strides().begin(), mat_indices_.strides().end()); // Determine dispatch kernel int tm = 4, tn = 4; int sm = 1, sn = 32; int bm = 1, bn = 1; int n_out_per_tgp; std::ostringstream kname; if (transpose_mat) { if (in_vector_len >= 8192 && out_vector_len >= 2048) { sm = 4; sn = 8; } else { sm = 8; sn = 4; } if (out_vector_len >= 2048) { bn = 16; } else if (out_vector_len >= 512) { bn = 4; } else { bn = 2; } // Specialized kernel for very small outputs tn = out_vector_len < tn ? 1 : tn; n_out_per_tgp = bn * sn * tn; kname << "gemv_t_gather_" << type_to_name(out); } else { bm = out_vector_len >= 4096 ? 8 : 4; sn = 32; // Specialized kernel for very small outputs tm = out_vector_len < tm ? 
1 : tm; n_out_per_tgp = bm * sm * tm; kname << "gemv_gather_" << type_to_name(out); } kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" << tm << "_tn" << tn; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); compute_encoder.set_input_array(mat, 0); compute_encoder.set_input_array(vec, 1); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(in_vector_len, 4); compute_encoder.set_bytes(out_vector_len, 5); compute_encoder.set_bytes(mat_ld, 6); compute_encoder.set_bytes(batch_ndim, 9); compute_encoder.set_vector_bytes(out.shape(), 10); compute_encoder.set_vector_bytes(index_strides, 11); compute_encoder.set_bytes(batch_ndim_vec, 12); compute_encoder.set_vector_bytes(vec.shape(), 13); compute_encoder.set_vector_bytes(vec.strides(), 14); compute_encoder.set_bytes(batch_ndim_mat, 15); compute_encoder.set_vector_bytes(mat.shape(), 16); compute_encoder.set_vector_bytes(mat.strides(), 17); compute_encoder.set_input_array(vec_indices_, 18); compute_encoder.set_input_array(mat_indices_, 19); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void gather_mm( const array& a_, const array& b_, const array& lhs_indices, const array& rhs_indices, array& out, int M, int N, int K, metal::Device& d, const Stream& s) { // Copy if needed std::vector copies; auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); d.add_temporaries(std::move(copies), s.index); // Determine dispatch kernel int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; size_t batch_size_out = out.size() / M / N; int batch_ndim = out.ndim() - 2; int batch_ndim_a = a.ndim() - 2; int batch_ndim_b 
= b.ndim() - 2; char devc = d.get_architecture().back(); GEMM_TPARAM_MACRO(devc) const bool has_batch = batch_ndim > 1; const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; // Define the kernel name std::string base_name; base_name.reserve(128); concatenate( base_name, "steel_gather_mm_", transpose_a ? 't' : 'n', transpose_b ? 't' : 'n', "_", type_to_name(a), "_", type_to_name(out), "_bm", bm, "_bn", bn, "_bk", bk, "_wm", wm, "_wn", wn); metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, }; // And the kernel hash that includes the function constants std::string hash_name; hash_name.reserve(128); concatenate( hash_name, base_name, "_has_batch_", has_batch ? 't' : 'n', "_align_M_", align_M ? 't' : 'n', "_align_N_", align_N ? 't' : 'n', "_align_K_", align_K ? 't' : 'n'); // Get and set the kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_gather_kernel( d, base_name, hash_name, func_consts, out, transpose_a, transpose_b, bm, bn, bk, wm, wn, false); compute_encoder.set_compute_pipeline_state(kernel); // Prepare the matmul params steel::GEMMParams params{/* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ static_cast(lda), /* const int ldb = */ static_cast(ldb), /* const int ldd = */ N, /* const int tiles_n = */ (N + bn - 1) / bn, /* const int tiles_m = */ (M + bm - 1) / bm, /* const int64_t batch_stride_a = */ (batch_ndim > 0) ? lhs_indices.strides()[0] : 0, /* const int64_t batch_stride_b = */ (batch_ndim > 0) ? 
rhs_indices.strides()[0] : 0, /* const int64_t batch_stride_d = */ M * N, /* const int swizzle_log = */ 0, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ batch_ndim}; // Prepare the grid MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(lhs_indices, 2); compute_encoder.set_input_array(rhs_indices, 3); compute_encoder.set_output_array(out, 4); compute_encoder.set_bytes(params, 5); compute_encoder.set_vector_bytes(lhs_indices.shape(), 6); compute_encoder.set_vector_bytes(lhs_indices.strides(), 7); compute_encoder.set_vector_bytes(rhs_indices.strides(), 8); compute_encoder.set_bytes(batch_ndim_a, 9); compute_encoder.set_vector_bytes(a.shape(), 10); compute_encoder.set_vector_bytes(a.strides(), 11); compute_encoder.set_bytes(batch_ndim_b, 12); compute_encoder.set_vector_bytes(b.shape(), 13); compute_encoder.set_vector_bytes(b.strides(), 14); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void GatherMM::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); auto& a = inputs[0]; auto& b = inputs[1]; auto& lhs_indices = inputs[2]; auto& rhs_indices = inputs[3]; // Return 0s if either input is empty if (a.size() == 0 || b.size() == 0) { array zero = array(0, a.dtype()); fill_gpu(zero, out, s); d.add_temporary(std::move(zero), s.index); return; } out.set_data(allocator::malloc(out.nbytes())); // Extract shapes from inputs. int M = a.shape(-2); int N = b.shape(-1); int K = a.shape(-1); // We are walking a in order and b is also in order so we can batch up the // matmuls and reuse reading a and b. 
if (M == 1 && right_sorted_ == true) { if (metal::is_nax_available() && (env::enable_tf32() || a.dtype() != float32)) { return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s); } gather_mm_rhs(a, b, rhs_indices, out, d, s); return; } // Route to gather gemv if any of a or b are vectors if (M == 1) { gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s); return; } if (N == 1) { gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s); return; } // Route to non specialized gather mm gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s); } void segmented_mm( const array& a_, const array& b_, const array& segments_, array& out, int M, int N, int K, metal::Device& d, const Stream& s) { auto check_segments_layout = [&d, &s](const array& x) { // Contiguous so return early if (x.flags().row_contiguous) { return std::make_tuple(true, x); } bool rc = true; for (int i = 0; i < x.ndim() - 2; i++) { rc &= (x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1); } rc &= x.strides(x.ndim() - 1) == 1; if (x.ndim() > 1) { rc &= x.strides(x.ndim() - 2) == 1; } if (rc) { return std::make_tuple(false, x); } array x_copy = contiguous_copy_gpu(x, s); d.add_temporary(x_copy, s.index); return std::make_tuple(true, x_copy); }; // Copy if needed std::vector copies; auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); auto [segments_contiguous, segments] = check_segments_layout(segments_); d.add_temporaries(std::move(copies), s.index); // Determine dispatch kernel int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; size_t batch_size_out = out.size() / M / N; char devc = d.get_architecture().back(); GEMM_TPARAM_MACRO(devc) const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; // Define the kernel name std::string base_name; base_name.reserve(128); concatenate( base_name, "steel_segmented_mm_", transpose_a ? 't' : 'n', transpose_b ? 
't' : 'n', "_", type_to_name(a), "_", type_to_name(out), "_bm", bm, "_bn", bn, "_bk", bk, "_wm", wm, "_wn", wn); metal::MTLFCList func_consts = { {&segments_contiguous, MTL::DataType::DataTypeBool, 199}, {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, }; // And the kernel hash that includes the function constants std::string hash_name; hash_name.reserve(128); concatenate( hash_name, base_name, "_segments_contiguous_", segments_contiguous ? 't' : 'n', "_align_M_", align_M ? 't' : 'n', "_align_N_", align_N ? 't' : 'n'); // Get and set the kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_segmented_kernel( d, base_name, hash_name, func_consts, out, transpose_a, transpose_b, bm, bn, bk, wm, wn); compute_encoder.set_compute_pipeline_state(kernel); // Prepare the matmul params steel::GEMMParams params{/* const int M = */ M, /* const int N = */ N, /* const int K = */ K, /* const int lda = */ static_cast(lda), /* const int ldb = */ static_cast(ldb), /* const int ldd = */ N, /* const int tiles_n = */ (N + bn - 1) / bn, /* const int tiles_m = */ (M + bm - 1) / bm, /* const int64_t batch_stride_a = */ 0, /* const int64_t batch_stride_b = */ 0, /* const int64_t batch_stride_d = */ M * N, /* const int swizzle_log = */ 0, /* const int gemm_k_iterations_aligned = */ 0, /* const int batch_ndim = */ 0}; // Prepare the grid MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(segments, 2); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(params, 4); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void SegmentedMM::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); auto& a = inputs[0]; auto& b = inputs[1]; auto& 
segments = inputs[2]; out.set_data(allocator::malloc(out.nbytes())); // Extract shapes from inputs. int M = a.shape(-2); int N = b.shape(-1); int K = a.shape(-1); segmented_mm(a, b, segments, out, M, N, K, d, s); } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/matmul.h ================================================ // Copyright © 2023 Apple Inc. #pragma once #include "mlx/backend/metal/device.h" namespace mlx::core { template void steel_matmul_regular_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, int64_t matrix_stride_out, int64_t C_batch_stride = 0, float alpha = 1.0f, float beta = 0.0f); inline void steel_matmul_regular( const Stream& s, metal::Device& d, const array& a, const array& b, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, int64_t matrix_stride_out) { return steel_matmul_regular_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ b, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* int ldd = */ ldd, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ batch_shape, /* Strides batch_strides = */ batch_strides, /* int64_t A_batch_stride = */ A_batch_stride, /* int64_t B_batch_stride = */ B_batch_stride, /* int64_t matrix_stride_out = */ matrix_stride_out); } template void 
steel_matmul_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, const array& c, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, Strides B_batch_stride = {}, Strides C_batch_stride = {}, float alpha = 1.0f, float beta = 0.0f); inline void steel_matmul( const Stream& s, metal::Device& d, const array& a, const array& b, array& out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, Strides B_batch_stride = {}) { return steel_matmul_axpby( /* const Stream& s = */ s, /* metal::Device& d = */ d, /* const array& a = */ a, /* const array& b = */ b, /* const array& c = */ b, /* array& out = */ out, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* int batch_size_out = */ batch_size_out, /* int lda = */ lda, /* int ldb = */ ldb, /* bool transpose_a = */ transpose_a, /* bool transpose_b = */ transpose_b, /* std::vector& copies = */ copies, /* Shape batch_shape = */ batch_shape, /* Strides A_batch_stride = */ A_batch_stride, /* Strides B_batch_stride = */ B_batch_stride); } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/metal.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/utils.h" namespace mlx::core::metal { bool is_available() { return true; } void start_capture(std::string path, NS::Object* object) { auto pool = new_scoped_memory_pool(); auto descriptor = MTL::CaptureDescriptor::alloc()->init(); descriptor->setCaptureObject(object); if (!path.empty()) { auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding); auto url = NS::URL::fileURLWithPath(string); descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument); descriptor->setOutputURL(url); } auto manager = MTL::CaptureManager::sharedCaptureManager(); NS::Error* error; bool started = manager->startCapture(descriptor, &error); descriptor->release(); if (!started) { std::ostringstream msg; msg << "[metal::start_capture] Failed to start: " << error->localizedDescription()->utf8String(); throw std::runtime_error(msg.str()); } } void start_capture(std::string path) { auto& device = metal::device(mlx::core::Device::gpu); return start_capture(path, device.mtl_device()); } void stop_capture() { auto pool = new_scoped_memory_pool(); auto manager = MTL::CaptureManager::sharedCaptureManager(); manager->stopCapture(); } } // namespace mlx::core::metal ================================================ FILE: mlx/backend/metal/metal.h ================================================ // Copyright © 2023-2024 Apple Inc. #pragma once #include #include #include #include "mlx/api.h" namespace mlx::core::metal { /* Check if the Metal backend is available. */ MLX_API bool is_available(); /** Capture a GPU trace, saving it to an absolute file `path` */ MLX_API void start_capture(std::string path = ""); MLX_API void stop_capture(); /** Get information about the GPU and system settings. 
*/ MLX_API const std::unordered_map>& device_info(); } // namespace mlx::core::metal ================================================ FILE: mlx/backend/metal/no_metal.cpp ================================================ // Copyright © 2025 Apple Inc. #include #include "mlx/backend/metal/metal.h" #include "mlx/fast.h" namespace mlx::core { namespace metal { bool is_available() { return false; } void start_capture(std::string) {} void stop_capture() {} const std::unordered_map>& device_info() { throw std::runtime_error( "[metal::device_info] Cannot get device info without metal backend"); }; } // namespace metal namespace fast { CustomKernelFunction metal_kernel( const std::string&, const std::vector&, const std::vector&, const std::string&, const std::string&, bool, bool) { throw std::runtime_error("[metal_kernel] No Metal back-end."); } } // namespace fast } // namespace mlx::core ================================================ FILE: mlx/backend/metal/nojit_kernels.cpp ================================================ // Copyright © 2024 Apple Inc. 
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

// Kernel getters for the non-JIT configuration (judging by the file name,
// nojit_kernels.cpp).
//
// Each getter ignores its dtype, tiling and option arguments, which are left
// unnamed, and looks the pipeline up by `kernel_name` alone. Getters that
// also receive `hash_name` and `func_consts` (a metal::MTLFCList of
// function-constant values) forward both unchanged to `Device::get_kernel`.
//
// NOTE(review): presumably every kernel requested here must already exist in
// the precompiled Metal library for this build. Confirm against the JIT
// implementations of the same getters.
namespace mlx::core {

// Arange, elementwise, copy, softmax and logsumexp kernels.
MTL::ComputePipelineState* get_arange_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype,
    Dtype,
    const char*) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype,
    Dtype,
    const char*) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype,
    Dtype,
    const char*) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype,
    const char*) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_dynamic_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_softmax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_logsumexp_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&) {
  return d.get_kernel(kernel_name);
}

// Scan, sort and reduction kernels.
MTL::ComputePipelineState* get_scan_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool,
    bool,
    const std::string&,
    const array&,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&,
    int,
    int) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_mb_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&,
    int,
    int) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string&,
    const std::string&,
    const Dtype&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_reduce_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string&,
    const std::string&,
    const Dtype&,
    const Dtype&,
    const std::string&,
    int,
    int,
    int) {
  return d.get_kernel(kernel_name);
}

// Steel GEMM kernels: fused, split-K, masked, gather and segmented.
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int,
    bool,
    bool) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&,
    bool) {
  return d.get_kernel(kernel_name);
}

// NOTE(review): unlike the other getters, the two mask parameters here are
// named even though they are unused. The template arguments of
// `std::optional` (here and in get_gemv_masked_kernel) appear to be missing
// from this copy of the file.
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const std::optional& mask_out,
    const std::optional& mask_op,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int,
    bool,
    bool) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int,
    bool) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

// GEMV, convolution, FFT and quantized-matmul kernels.
MTL::ComputePipelineState* get_gemv_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const std::optional&,
    const std::optional&,
    bool,
    int,
    int,
    int,
    int,
    int,
    int,
    bool) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    int,
    int,
    int,
    int,
    int,
    int,
    bool) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_conv_3d_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    int,
    int,
    int,
    int,
    int,
    bool) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_conv_general_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    int,
    int,
    int,
    int,
    int) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_fft_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const std::string&) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_quantized_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string&,
    const std::string&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_gather_qmm_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    int,
    int,
    const std::string&,
    int,
    int,
    int,
    int,
    int,
    bool) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

// `_nax` variants of the GEMM and quantized-matmul getters.
MTL::ComputePipelineState* get_steel_gemm_fused_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_gemm_gather_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int,
    bool) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_qmm_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string&,
    const std::string&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_gather_qmm_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    int,
    int,
    const std::string&,
    int,
    int,
    int,
    int,
    int,
    bool) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

// Attention kernels (regular and `_nax`).
MTL::ComputePipelineState* get_steel_attention_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    int,
    int,
    int,
    int,
    int,
    const array&) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_attention_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    int,
    int,
    int,
    int,
    int,
    const array&) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/normalization.cpp
================================================
// Copyright © 2024 Apple Inc.
#include #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/reduce.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" namespace mlx::core::fast { bool RMSNorm::use_fallback(Stream s) { return s.device == Device::cpu; } void RMSNorm::eval_gpu( const std::vector& inputs, std::vector& outputs) { auto& s = stream(); auto& d = metal::device(s.device); auto& out = outputs[0]; // Make sure that the last dimension is contiguous auto set_output = [&s, &out](const array& x) { bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { out.copy_shared_buffer(x); } else { out.set_data( allocator::malloc(x.data_size() * x.itemsize()), x.data_size(), x.strides(), x.flags()); } return x; } else { array x_copy = contiguous_copy_gpu(x, s); out.copy_shared_buffer(x_copy); return x_copy; } }; const array x = set_output(inputs[0]); const array& w = inputs[1]; auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; const int simd_size = 32; const int n_reads = RMS_N_READS; const int looped_limit = RMS_LOOPED_LIMIT; std::string op_name = "rms"; if (axis_size > looped_limit) { op_name += "_looped"; } op_name += type_to_name(out); auto& compute_encoder = d.get_command_encoder(s.index); { auto kernel = d.get_kernel(op_name); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = 
MTL::Size(threadgroup_size, 1, 1); } else { size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup(); size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x, 0); compute_encoder.set_input_array(w, 1); compute_encoder.set_output_array(out, 2); compute_encoder.set_bytes(eps_, 3); compute_encoder.set_bytes(axis_size, 4); compute_encoder.set_bytes(w_stride, 5); compute_encoder.dispatch_threads(grid_dims, group_dims); } } void RMSNormVJP::eval_gpu( const std::vector& inputs, std::vector& outputs) { auto& s = stream(); auto& d = metal::device(s.device); // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. auto check_input = [&s](const array& x) -> std::pair { if (x.flags().row_contiguous) { return {x, false}; } array x_copy = contiguous_copy_gpu(x, s); return {x_copy, true}; }; bool donate_g = inputs[2].is_donatable(); auto [x, copied] = check_input(inputs[0]); const array& w = inputs[1]; auto [g, g_copied] = check_input(inputs[2]); donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; // Check whether we had a weight bool has_w = w.ndim() != 0; // Allocate space for the outputs bool g_in_gx = false; if (x.is_donatable()) { gx.copy_shared_buffer(x); } else if (g.is_donatable()) { gx.copy_shared_buffer(g); g_in_gx = true; } else { gx.set_data(allocator::malloc(gx.nbytes())); } if (g_copied && !g_in_gx) { d.add_temporary(g, s.index); } auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; // Allocate the gradient accumulator gw and a temporary to store the // gradients before they are accumulated. array gw_temp = (has_w) ? 
array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; if (has_w) { if (!g_in_gx && donate_g) { gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); d.add_temporary(gw_temp, s.index); } } gw.set_data(allocator::malloc(gw.nbytes())); const int simd_size = 32; const int n_reads = RMS_N_READS; const int looped_limit = RMS_LOOPED_LIMIT; std::string op_name = "vjp_rms"; if (axis_size > looped_limit) { op_name += "_looped"; } op_name += type_to_name(gx); std::string hash_name = op_name + ((has_w) ? "_w" : "_now"); metal::MTLFCList func_consts = { {&has_w, MTL::DataType::DataTypeBool, 20}, }; auto& compute_encoder = d.get_command_encoder(s.index); { auto kernel = d.get_kernel(op_name, hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } else { size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup(); size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } uint32_t w_stride = (w.ndim() == 1) ? 
w.strides()[0] : 0; compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x, 0); compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(g, 2); compute_encoder.set_output_array(gx, 3); compute_encoder.set_output_array(gw_temp, 4); compute_encoder.set_bytes(eps_, 5); compute_encoder.set_bytes(axis_size, 6); compute_encoder.set_bytes(w_stride, 7); compute_encoder.dispatch_threads(grid_dims, group_dims); } if (has_w) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); strided_reduce_general_dispatch( gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); } } bool LayerNorm::use_fallback(Stream s) { return s.device == Device::cpu; } void LayerNorm::eval_gpu( const std::vector& inputs, std::vector& outputs) { auto& s = stream(); auto& d = metal::device(s.device); auto& out = outputs[0]; // Make sure that the last dimension is contiguous auto set_output = [&s, &out](const array& x) { bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { out.copy_shared_buffer(x); } else { out.set_data( allocator::malloc(x.data_size() * x.itemsize()), x.data_size(), x.strides(), x.flags()); } return x; } else { array x_copy = contiguous_copy_gpu(x, s); out.copy_shared_buffer(x_copy); return x_copy; } }; const array x = set_output(inputs[0]); const array& w = inputs[1]; const array& b = inputs[2]; auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; int simd_size = 32; int n_reads = 8; int looped_limit = 6656; std::string op_name = "layer_norm"; if (axis_size > looped_limit) { op_name += "_looped"; n_reads = 4; } op_name += type_to_name(out); auto& compute_encoder = d.get_command_encoder(s.index); { auto kernel = d.get_kernel(op_name); MTL::Size grid_dims, group_dims; if (axis_size <= 
looped_limit) { size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) { std::ostringstream msg; msg << "[layer_norm] Threadgroup size " << threadgroup_size << " is larger than the maximum allowed threadgroup size " << kernel->maxTotalThreadsPerThreadgroup(); throw std::runtime_error(msg.str()); } size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } else { size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup(); size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x, 0); compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(b, 2); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(eps_, 4); compute_encoder.set_bytes(axis_size, 5); compute_encoder.set_bytes(w_stride, 6); compute_encoder.set_bytes(b_stride, 7); compute_encoder.dispatch_threads(grid_dims, group_dims); } } void LayerNormVJP::eval_gpu( const std::vector& inputs, std::vector& outputs) { auto& s = stream(); auto& d = metal::device(s.device); // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. 
auto check_input = [&s](const array& x) -> std::pair { if (x.flags().row_contiguous) { return {x, false}; } array x_copy = contiguous_copy_gpu(x, s); return {x_copy, true}; }; bool donate_x = inputs[0].is_donatable(); bool donate_g = inputs[3].is_donatable(); auto [x, copied] = check_input(inputs[0]); donate_x |= copied; const array& w = inputs[1]; auto [g, g_copied] = check_input(inputs[3]); donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; array& gb = outputs[2]; // Check whether we had a weight bool has_w = w.ndim() != 0; // Allocate space for the outputs bool g_in_gx = false; if (donate_x) { gx.copy_shared_buffer(x); } else if (donate_g) { gx.copy_shared_buffer(g); g_in_gx = true; } else { gx.set_data(allocator::malloc(gx.nbytes())); } if (g_copied && !g_in_gx) { d.add_temporary(g, s.index); } auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; // Allocate a temporary to store the gradients for w and allocate the output // gradient accumulators. array gw_temp = (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; if (has_w) { if (!g_in_gx && donate_g) { gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); d.add_temporary(gw_temp, s.index); } } gw.set_data(allocator::malloc(gw.nbytes())); gb.set_data(allocator::malloc(gb.nbytes())); // Finish with the gradient for b in case we had a b auto& compute_encoder = d.get_command_encoder(s.index); if (gb.ndim() == 1 && gb.size() == axis_size) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); strided_reduce_general_dispatch( g, gb, "sum", plan, {0}, compute_encoder, d, s); } int simd_size = 32; int n_reads = 8; int looped_limit = 8192; std::string op_name = "vjp_layer_norm"; if (axis_size > looped_limit) { op_name += "_looped"; n_reads = 4; } op_name += type_to_name(gx); std::string hash_name = op_name + ((has_w) ? 
"_w" : "_now"); metal::MTLFCList func_consts = { {&has_w, MTL::DataType::DataTypeBool, 20}, }; { auto kernel = d.get_kernel(op_name, hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) { std::ostringstream msg; msg << "[vjp_layer_norm] Threadgroup size " << threadgroup_size << " is larger than the maximum allowed threadgroup size " << kernel->maxTotalThreadsPerThreadgroup(); throw std::runtime_error(msg.str()); } size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } else { size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup(); size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x, 0); compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(g, 2); compute_encoder.set_output_array(gx, 3); compute_encoder.set_output_array(gw_temp, 4); compute_encoder.set_bytes(eps_, 5); compute_encoder.set_bytes(axis_size, 6); compute_encoder.set_bytes(w_stride, 7); compute_encoder.dispatch_threads(grid_dims, group_dims); } if (has_w) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); strided_reduce_general_dispatch( gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); } } } // namespace mlx::core::fast ================================================ FILE: mlx/backend/metal/primitives.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include #include #include #include #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" #include "mlx/utils.h" namespace mlx::core { template void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { enc.set_bytes(start, 0); T step = next - start; enc.set_bytes(step, 1); } void Arange::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 0); out.set_data(allocator::malloc(out.nbytes())); if (out.size() == 0) { return; } auto& s = stream(); auto& d = metal::device(s.device); auto kernel = get_arange_kernel(d, "arange" + type_to_name(out), out); size_t nthreads = out.size(); MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); MTL::Size group_dims = MTL::Size( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); switch (out.dtype()) { case bool_: // unsupported throw std::runtime_error("[Arange::eval_gpu] Does not support bool"); case uint8: arange_set_scalars(start_, start_ + step_, compute_encoder); break; case uint16: arange_set_scalars(start_, start_ + step_, compute_encoder); break; case uint32: arange_set_scalars(start_, start_ + step_, compute_encoder); break; case uint64: arange_set_scalars(start_, start_ + step_, compute_encoder); break; case int8: arange_set_scalars(start_, start_ + step_, compute_encoder); break; case int16: arange_set_scalars(start_, start_ + step_, compute_encoder); break; case int32: arange_set_scalars(start_, start_ + step_, compute_encoder); break; case int64: arange_set_scalars(start_, start_ + step_, compute_encoder); break; case float16: arange_set_scalars(start_, start_ + step_, compute_encoder); break; case float32: 
arange_set_scalars(start_, start_ + step_, compute_encoder); break; case bfloat16: arange_set_scalars(start_, start_ + step_, compute_encoder); break; default: throw std::runtime_error("[Arange::eval_gpu] Does not support type."); } compute_encoder.set_output_array(out, 2); compute_encoder.dispatch_threads(grid_dims, group_dims); } void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); std::string op_name; switch (reduce_type_) { case ArgReduce::ArgMin: op_name = "argmin_"; break; case ArgReduce::ArgMax: op_name = "argmax_"; break; } // Prepare the shapes, strides and axis arguments. auto in_strides = in.strides(); auto shape = in.shape(); auto out_strides = out.strides(); auto axis_stride = in_strides[axis_]; size_t axis_size = shape[axis_]; if (out_strides.size() == in_strides.size()) { out_strides.erase(out_strides.begin() + axis_); } in_strides.erase(in_strides.begin() + axis_); shape.erase(shape.begin() + axis_); size_t ndim = shape.size(); // ArgReduce int simd_size = 32; int n_reads = 4; auto& compute_encoder = d.get_command_encoder(s.index); { auto kernel = d.get_kernel(op_name + type_to_name(in)); NS::UInteger thread_group_size = std::min( (axis_size + n_reads - 1) / n_reads, kernel->maxTotalThreadsPerThreadgroup()); // round up to the closest number divisible by simd_size thread_group_size = (thread_group_size + simd_size - 1) / simd_size * simd_size; assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); auto gd = get_2d_grid_dims(out.shape(), out.strides()); MTL::Size grid_dims = MTL::Size(thread_group_size, gd.width, gd.height); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); if (ndim == 0) { // Pass place holders so metal doesn't 
complain int shape_ = 0; int64_t stride_ = 0; compute_encoder.set_bytes(shape_, 2); compute_encoder.set_bytes(stride_, 3); compute_encoder.set_bytes(stride_, 4); } else { compute_encoder.set_vector_bytes(shape, 2); compute_encoder.set_vector_bytes(in_strides, 3); compute_encoder.set_vector_bytes(out_strides, 4); } compute_encoder.set_bytes(ndim, 5); compute_encoder.set_bytes(axis_stride, 6); compute_encoder.set_bytes(axis_size, 7); compute_encoder.dispatch_threads(grid_dims, group_dims); } } void Load::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error("[Load::eval_gpu] Not implemented."); } void RandomBits::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); // keys has shape (N1, ..., NK, 2) // out has shape (N1, ..., NK, M1, M2, ...) auto& keys = inputs[0]; size_t num_keys = keys.size() / 2; size_t elems_per_key = out.size() / num_keys; size_t bytes_per_key = out.itemsize() * elems_per_key; out.set_data(allocator::malloc(out.nbytes())); if (out.size() == 0) { return; } size_t out_per_key = (bytes_per_key + 4 - 1) / 4; size_t half_size = out_per_key / 2; bool odd = out_per_key % 2; auto& s = stream(); auto& d = metal::device(s.device); std::string kname = keys.flags().row_contiguous ? 
"rbitsc" : "rbits"; auto kernel = d.get_kernel(kname); // organize into grid nkeys x elem_per_key MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1); auto group_dims = get_block_dims(num_keys, half_size + odd, 1); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(keys, 0); compute_encoder.set_output_array(out, 1); compute_encoder.set_bytes(odd, 2); compute_encoder.set_bytes(bytes_per_key, 3); if (!keys.flags().row_contiguous) { int ndim = keys.ndim(); compute_encoder.set_bytes(ndim, 4); compute_encoder.set_vector_bytes(keys.shape(), 5); compute_encoder.set_vector_bytes(keys.strides(), 6); } compute_encoder.dispatch_threads(grid_dims, group_dims); } void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI."); } void SVD::eval_gpu( const std::vector& inputs, std::vector& outputs) { throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI."); } void Inverse::eval_gpu(const std::vector& inputs, array& output) { throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI."); } void Cholesky::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } void Eig::eval_gpu( const std::vector& inputs, std::vector& outputs) { throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI."); } void Eigh::eval_gpu( const std::vector& inputs, std::vector& outputs) { throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI."); } void LUF::eval_gpu( const std::vector& inputs, std::vector& outputs) { throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/quantized.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/reduce.h" #include "mlx/backend/metal/unary.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { namespace { template auto get_quantized_kernel_wrapped( metal::Device& d, const std::string& name, const std::string& func, const std::string& mode, const std::string& type, int group_size, int bits, Args... args) { std::string template_def; std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func; template_def = get_template_definition( name, fname, type, group_size, bits, std::forward(args)...); return get_quantized_kernel(d, name, template_def, mode); } template auto get_qmm_nax_kernel_wrapped( metal::Device& d, const std::string& name, const std::string& func, const std::string& mode, const std::string& type, int group_size, int bits, Args... args) { std::string template_def; std::string fname = ((mode == "affine") ? 
"affine_" : "fp_") + func; template_def = get_template_definition( name, fname, type, group_size, bits, std::forward(args)...); return get_qmm_nax_kernel(d, name, template_def, mode); } inline array ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { if (!x.flags().row_contiguous) { array x_copy = contiguous_copy_gpu(x, s); d.add_temporary(x_copy, s.index); return x_copy; } else { return x; } } inline array ensure_row_contiguous_matrix( const array& x, metal::Device& d, const Stream& s) { if (x.ndim() < 2) { if (x.strides()[0] == 1) { return x; } } else { auto stride_0 = x.strides()[x.ndim() - 2]; auto stride_1 = x.strides()[x.ndim() - 1]; if (stride_0 == x.shape(-1) && stride_1 == 1) { return x; } } array x_copy = contiguous_copy_gpu(x, s); d.add_temporary(x_copy, s.index); return x_copy; } inline int get_qmv_batch_limit(int D, int O, metal::Device& d) { auto arch_size = d.get_architecture().back(); auto arch_gen = d.get_architecture_gen(); if (arch_gen == 13 || arch_gen == 14) { switch (arch_size) { case 'd': if (D <= 2048 && O <= 2048) { return 32; } else if (D <= 4096 && O <= 4096) { return 18; } else { return 12; } default: if (D <= 2048 && O <= 2048) { return 14; } else if (D <= 4096 && O <= 4096) { return 10; } else { return 6; } } } else { switch (arch_size) { case 'd': if (D <= 2048 && O <= 2048) { return 32; } else if (D <= 4096 && O <= 4096) { return 18; } else { return 12; } default: if (D <= 2048 && O <= 2048) { return 18; } else if (D <= 4096 && O <= 4096) { return 12; } else { return 10; } } } } inline int add_strides_and_shapes( CommandEncoder& compute_encoder, bool skip, const array& x, const array& w, const array& scales, const std::optional& biases, int offset) { if (skip) { return 0; } // TODO: Collapse batch dimensions int x_batch_ndims = x.ndim() - 2; int w_batch_ndims = w.ndim() - 2; compute_encoder.set_bytes(x_batch_ndims, offset++); compute_encoder.set_vector_bytes(x.shape(), offset++); 
compute_encoder.set_vector_bytes(x.strides(), offset++); compute_encoder.set_bytes(w_batch_ndims, offset++); compute_encoder.set_vector_bytes(w.shape(), offset++); compute_encoder.set_vector_bytes(w.strides(), offset++); compute_encoder.set_vector_bytes(scales.strides(), offset++); if (biases) { compute_encoder.set_vector_bytes(biases->strides(), offset++); } return offset; } inline int add_gather_strides_and_shapes( CommandEncoder& compute_encoder, const array& lhs_indices, const array& rhs_indices, int offset) { auto [shape, strides] = collapse_contiguous_dims( lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()}); int ndims = shape.size(); compute_encoder.set_bytes(ndims, offset++); compute_encoder.set_vector_bytes(shape, offset++); compute_encoder.set_vector_bytes(strides[0], offset++); compute_encoder.set_vector_bytes(strides[1], offset++); return offset; } } // namespace void qmv_quad( const array& x, const array& w, const array& scales, const std::optional& biases, array& out, int group_size, int bits, int M, int N, int K, metal::Device& d, const Stream& s, const std::string& mode) { int B = out.size() / M / N; constexpr int quads_per_simd = 8; constexpr int results_per_quadgroup = 8; int bn = quads_per_simd * results_per_quadgroup; int simdgroup_size = 32; MTL::Size group_dims(simdgroup_size, 1, 1); MTL::Size grid_dims(M, (N + bn - 1) / bn, B); std::string kname; kname.reserve(64); std::string type_string = get_type_string(x.dtype()); concatenate( kname, mode + "_qmv_quad_", type_string, "_gs_", group_size, "_b_", bits, "_d_", K, B > 1 ? 
"_batch_1" : "_batch_0"); auto kernel = get_quantized_kernel_wrapped( d, kname, "qmv_quad", mode, type_string, group_size, bits, K, B > 1); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; compute_encoder.set_input_array(w, c++); compute_encoder.set_input_array(scales, c++); if (biases) { compute_encoder.set_input_array(*biases, c++); } compute_encoder.set_input_array(x, c++); compute_encoder.set_output_array(out, c++); compute_encoder.set_bytes(K, c++); compute_encoder.set_bytes(N, c++); add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void qmv( const array& x, const array& w, const array& scales, const std::optional& biases, array& out, int group_size, int bits, int M, int N, int K, metal::Device& d, const Stream& s, const std::string& mode) { int B = out.size() / M / N; int bn = 8; int bk = 32; MTL::Size group_dims(bk, 2, 1); MTL::Size grid_dims(M, (N + bn - 1) / bn, B); std::string kname; kname.reserve(64); std::string type_string = get_type_string(x.dtype()); bool fast = N % bn == 0 && K % 512 == 0; concatenate( kname, mode + (fast ? "_qmv_fast_" : "_qmv_"), type_string, "_gs_", group_size, "_b_", bits, B > 1 ? "_batch_1" : "_batch_0"); auto kernel = get_quantized_kernel_wrapped( d, kname, (fast ? 
"qmv_fast" : "qmv"), mode, type_string, group_size, bits, B > 1); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; compute_encoder.set_input_array(w, c++); compute_encoder.set_input_array(scales, c++); if (biases) { compute_encoder.set_input_array(*biases, c++); } compute_encoder.set_input_array(x, c++); compute_encoder.set_output_array(out, c++); compute_encoder.set_bytes(K, c++); compute_encoder.set_bytes(N, c++); add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void qvm_split_k( const array& x, const array& w, const array& scales, const std::optional& biases, array& out, int group_size, int bits, int M, int N, int K, metal::Device& d, const Stream& s, const std::string& mode) { int split_k = K > 8192 ? 32 : 8; int split_D = (K + split_k - 1) / split_k; int B = out.size() / M / N; B *= split_k; constexpr int num_simdgroups = 2; constexpr int bk = 32; int bn = std::min(group_size, 32) * num_simdgroups; MTL::Size group_dims = MTL::Size(bk, num_simdgroups, 1); MTL::Size grid_dims = MTL::Size(M, N / bn, B); auto x_shape = x.shape(); auto x_strides = x.strides(); if (x_shape.size() == 1) { x_shape.insert(x_shape.begin(), 1); x_strides.insert(x_strides.begin(), 0); } int x_ndim = x_shape.size(); int x_batch_ndims = x_ndim - 2; int w_batch_ndims = w.ndim() - 2; auto w_shape = w.shape(); auto w_strides = w.strides(); auto s_strides = scales.strides(); // Add split_k dim with reshapes x_shape.insert(x_shape.end() - 2, split_k); x_shape.back() /= split_k; x_strides.insert(x_strides.end() - 2, split_D); x_strides[x_ndim - 1] = split_D; x_batch_ndims += 1; w_shape.insert(w_shape.end() - 2, split_k); w_shape[w.ndim() - 1] /= split_k; w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1)); w_batch_ndims += 1; s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1)); int final_block_size = K - (split_k - 1) 
* split_D; auto temp_shape = out.shape(); if (temp_shape.size() == 1) { temp_shape.insert(temp_shape.begin(), 1); } temp_shape.insert(temp_shape.end() - 2, split_k); array intermediate(temp_shape, x.dtype(), nullptr, {}); intermediate.set_data(allocator::malloc(intermediate.nbytes())); d.add_temporary(intermediate, s.index); std::string type_string = get_type_string(x.dtype()); std::string kname; kname.reserve(64); concatenate( kname, mode + "_qvm_split_k_", type_string, "_gs_", group_size, "_b_", bits, "_spk_", split_k); // Encode and dispatch kernel auto kernel = get_quantized_kernel_wrapped( d, kname, "qvm_split_k", mode, type_string, group_size, bits, split_k); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; compute_encoder.set_input_array(w, c++); compute_encoder.set_input_array(scales, c++); if (biases) { compute_encoder.set_input_array(*biases, c++); } compute_encoder.set_input_array(x, c++); compute_encoder.set_output_array(intermediate, c++); compute_encoder.set_bytes(split_D, c++); compute_encoder.set_bytes(N, c++); compute_encoder.set_bytes(x_batch_ndims, c++); compute_encoder.set_vector_bytes(x_shape, c++); compute_encoder.set_vector_bytes(x_strides, c++); compute_encoder.set_bytes(w_batch_ndims, c++); compute_encoder.set_vector_bytes(w_shape, c++); compute_encoder.set_vector_bytes(w_strides, c++); compute_encoder.set_vector_bytes(s_strides, c++); if (biases) { auto b_strides = biases->strides(); b_strides.insert(b_strides.end() - 2, split_D * biases->shape(-1)); compute_encoder.set_vector_bytes(b_strides, c++); } compute_encoder.set_bytes(final_block_size, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); int axis = intermediate.ndim() - 3; ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {intermediate.shape(axis)}, {intermediate.strides(axis)}); strided_reduce_general_dispatch( intermediate, out, "sum", plan, {axis}, compute_encoder, d, s); } void 
qvm( const array& x, const array& w, const array& scales, const std::optional& biases, array& out, int group_size, int bits, int M, int N, int K, metal::Device& d, const Stream& s, const std::string& mode) { int B = out.size() / M / N; constexpr int num_simdgroups = 2; constexpr int bk = 32; int bn = std::min(group_size, 32) * num_simdgroups; MTL::Size group_dims(bk, num_simdgroups, 1); MTL::Size grid_dims(M, (N + bn - 1) / bn, B); std::string kname; kname.reserve(64); std::string type_string = get_type_string(x.dtype()); concatenate( kname, mode + "_qvm_", type_string, "_gs_", group_size, "_b_", bits, B > 1 ? "_batch_1" : "_batch_0"); auto kernel = get_quantized_kernel_wrapped( d, kname, "qvm", mode, type_string, group_size, bits, B > 1); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; compute_encoder.set_input_array(w, c++); compute_encoder.set_input_array(scales, c++); if (biases) { compute_encoder.set_input_array(*biases, c++); } compute_encoder.set_input_array(x, c++); compute_encoder.set_output_array(out, c++); compute_encoder.set_bytes(K, c++); compute_encoder.set_bytes(N, c++); add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void qmm_nax( const array& x, const array& w, const array& scales, const std::optional& biases, array& out, bool transpose, int group_size, int bits, int M, int N, int K, metal::Device& d, const Stream& s, const std::string& mode) { int B = out.size() / M / N; int wm = 2; int wn = 2; int bm = 64; int bn = 64; int bk = 64; MTL::Size group_dims(32, wn, wm); MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); std::string kname; kname.reserve(64); bool aligned = N % 64 == 0; bool batched = B > 1; std::string type_string = get_type_string(x.dtype()); concatenate( kname, mode + (transpose ? 
"_qmm_t_nax_" : "_qmm_n_nax_"), type_string, "_gs_", group_size, "_b_", bits, "_bm", bm, "_bn", bn, "_bk", bk, "_wm", wm, "_wn", wn, transpose ? (aligned ? "_alN_true" : "_alN_false") : "", batched ? "_batch_1" : "_batch_0"); std::string template_def; MTL::ComputePipelineState* kernel; if (transpose) { kernel = get_qmm_nax_kernel_wrapped( d, kname, "qmm_t_nax", mode, type_string, group_size, bits, aligned, batched, bm, bk, bn, wm, wn); } else { kernel = get_qmm_nax_kernel_wrapped( d, kname, "qmm_n_nax", mode, type_string, group_size, bits, batched, bm, bk, bn, wm, wn); } auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; compute_encoder.set_input_array(w, c++); compute_encoder.set_input_array(scales, c++); if (biases) { compute_encoder.set_input_array(*biases, c++); } compute_encoder.set_input_array(x, c++); compute_encoder.set_output_array(out, c++); compute_encoder.set_bytes(K, c++); compute_encoder.set_bytes(N, c++); compute_encoder.set_bytes(M, c++); add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void gather_qmm_nax( const array& x, const array& w, const array& scales, const std::optional& biases, const array& lhs_indices, const array& rhs_indices, array& out, bool transpose, int group_size, int bits, int M, int N, int K, metal::Device& d, const Stream& s, const std::string& mode) { int B = out.size() / M / N; int wm = 2; int wn = 2; int bm = 64; int bn = 64; int bk = 32; MTL::Size group_dims(32, wn, wm); MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); std::string kname; kname.reserve(64); bool aligned = N % 64 == 0; std::string type_string = get_type_string(x.dtype()); concatenate( kname, mode + (transpose ? "_gather_qmm_t_nax_" : "_gather_qmm_n_nax_"), type_string, "_gs_", group_size, "_b_", bits, "_bm", bm, "_bn", bn, "_bk", bk, "_wm", wm, "_wn", wn, transpose ? (aligned ? 
"_alN_true" : "_alN_false") : ""); MTL::ComputePipelineState* kernel; if (transpose) { kernel = get_qmm_nax_kernel_wrapped( d, kname, "gather_qmm_t_nax_", mode, type_string, group_size, bits, aligned, bm, bk, bn, wm, wn); } else { kernel = get_qmm_nax_kernel_wrapped( d, kname, "gather_qmm_n_nax_", mode, type_string, group_size, bits, bm, bk, bn, wm, wn); } auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; compute_encoder.set_input_array(w, c++); compute_encoder.set_input_array(scales, c++); if (biases) { compute_encoder.set_input_array(*biases, c++); } compute_encoder.set_input_array(x, c++); compute_encoder.set_input_array(lhs_indices, c++); compute_encoder.set_input_array(rhs_indices, c++); compute_encoder.set_output_array(out, c++); compute_encoder.set_bytes(K, c++); compute_encoder.set_bytes(N, c++); compute_encoder.set_bytes(M, c++); c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c); add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void qmm( const array& x, const array& w, const array& scales, const std::optional& biases, array& out, bool transpose, int group_size, int bits, int M, int N, int K, metal::Device& d, const Stream& s, const std::string& mode) { if (metal::is_nax_available() && transpose && (K % 64 == 0) && (env::enable_tf32() || x.dtype() != float32)) { return qmm_nax( /* const array& x = */ x, /* const array& w = */ w, /* const array& scales = */ scales, /* const std::optional& biases = */ biases, /* array& out = */ out, /* bool transpose = */ transpose, /* int group_size = */ group_size, /* int bits = */ bits, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* metal::Device& d = */ d, /* const Stream& s = */ s, /* const std::string& mode = */ mode); } int B = out.size() / M / N; int wm = 2; int wn = 2; int bm = 32; int bn = 32; MTL::Size group_dims(32, wn, 
wm); MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); std::string kname; kname.reserve(64); bool aligned = N % 32 == 0; bool batched = B > 1; std::string type_string = get_type_string(x.dtype()); concatenate( kname, mode + (transpose ? "_qmm_t_" : "_qmm_n_"), type_string, "_gs_", group_size, "_b_", bits, transpose ? (aligned ? "_alN_true" : "_alN_false") : "", batched ? "_batch_1" : "_batch_0"); std::string template_def; MTL::ComputePipelineState* kernel; if (transpose) { kernel = get_quantized_kernel_wrapped( d, kname, "qmm_t", mode, type_string, group_size, bits, aligned, batched); } else { kernel = get_quantized_kernel_wrapped( d, kname, "qmm_n", mode, type_string, group_size, bits, batched); } auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; compute_encoder.set_input_array(w, c++); compute_encoder.set_input_array(scales, c++); if (biases) { compute_encoder.set_input_array(*biases, c++); } compute_encoder.set_input_array(x, c++); compute_encoder.set_output_array(out, c++); compute_encoder.set_bytes(K, c++); compute_encoder.set_bytes(N, c++); compute_encoder.set_bytes(M, c++); add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void gather_qmm( const array& x, const array& w, const array& scales, const std::optional& biases, const array& lhs_indices, const array& rhs_indices, array& out, bool transpose, int group_size, int bits, int M, int N, int K, metal::Device& d, const Stream& s, const std::string& mode) { if (metal::is_nax_available() && transpose && (K % 64 == 0) && (env::enable_tf32() || x.dtype() != float32)) { return gather_qmm_nax( /* const array& x = */ x, /* const array& w = */ w, /* const array& scales = */ scales, /* const std::optional& biases = */ biases, /* const array& lhs_indices = */ lhs_indices, /* const array& rhs_indices = */ rhs_indices, /* array& out = */ out, /* bool 
transpose = */ transpose, /* int group_size = */ group_size, /* int bits = */ bits, /* int M = */ M, /* int N = */ N, /* int K = */ K, /* metal::Device& d = */ d, /* const Stream& s = */ s, /* const std::string& mode = */ mode); } int B = out.size() / M / N; int wm = 2; int wn = 2; int bm = 32; int bn = 32; MTL::Size group_dims(32, wn, wm); MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); std::string kname; kname.reserve(64); bool aligned = N % 32 == 0; std::string type_string = get_type_string(x.dtype()); concatenate( kname, mode + (transpose ? "_gather_qmm_t_" : "_gather_qmm_n_"), type_string, "_gs_", group_size, "_b_", bits, transpose ? (aligned ? "_alN_true" : "_alN_false") : ""); MTL::ComputePipelineState* kernel; if (transpose) { kernel = get_quantized_kernel_wrapped( d, kname, "gather_qmm_t", mode, type_string, group_size, bits, aligned); } else { kernel = get_quantized_kernel_wrapped( d, kname, "gather_qmm_n", mode, type_string, group_size, bits); } auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int c = 0; compute_encoder.set_input_array(w, c++); compute_encoder.set_input_array(scales, c++); if (biases) { compute_encoder.set_input_array(*biases, c++); } compute_encoder.set_input_array(x, c++); compute_encoder.set_input_array(lhs_indices, c++); compute_encoder.set_input_array(rhs_indices, c++); compute_encoder.set_output_array(out, c++); compute_encoder.set_bytes(K, c++); compute_encoder.set_bytes(N, c++); compute_encoder.set_bytes(M, c++); c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c); add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void gather_qmv( const array& x, const array& w, const array& scales, const std::optional& biases, const array& lhs_indices, const array& rhs_indices, array& out, int group_size, int bits, int M, int N, int K, metal::Device& d, const 
Stream& s,
    const std::string& mode) {
  int B = out.size() / M / N;

  int bn = 8;
  int bk = 32;
  MTL::Size group_dims(bk, 2, 1);
  MTL::Size grid_dims(M, (N + bn - 1) / bn, B);

  std::string kname;
  kname.reserve(64);
  std::string type_string = get_type_string(x.dtype());
  bool fast = N % bn == 0 && K % 512 == 0;
  concatenate(
      kname,
      mode + (fast ? "_gather_qmv_fast_" : "_gather_qmv_"),
      type_string,
      "_gs_",
      group_size,
      "_b_",
      bits);
  auto kernel = get_quantized_kernel_wrapped(
      d,
      kname,
      (fast ? "gather_qmv_fast" : "gather_qmv"),
      mode,
      type_string,
      group_size,
      bits);

  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  int c = 0;
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases) {
    compute_encoder.set_input_array(*biases, c++);
  }
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_input_array(lhs_indices, c++);
  compute_encoder.set_input_array(rhs_indices, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(K, c++);
  compute_encoder.set_bytes(N, c++);
  c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
  add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Encode a gathered quantized vector-matrix product (w not transposed).
// GatherQMM::eval_gpu routes here when M is below its vector limit and
// transpose_ is false. Each threadgroup is bk x num_simdgroups threads and
// covers bn output columns; the grid is (M, ceil(N / bn), B) with
// B = out.size() / (M * N).
//
// Buffer layout: w, scales, [biases], x, lhs_indices, rhs_indices, out, K, N,
// then the strides/shapes appended by the two add_* helpers.
void gather_qvm(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,
    int group_size,
    int bits,
    int M,
    int N,
    int K,
    metal::Device& d,
    const Stream& s,
    const std::string& mode) {
  // Number of independent (M x N) output matrices in the batch.
  int B = out.size() / M / N;

  constexpr int num_simdgroups = 2;
  constexpr int bk = 32;
  // Output columns per threadgroup: each simdgroup covers up to 32 columns
  // (fewer when group_size < 32).
  int bn = std::min(group_size, 32) * num_simdgroups;
  MTL::Size group_dims(bk, num_simdgroups, 1);
  MTL::Size grid_dims(M, (N + bn - 1) / bn, B);

  std::string kname;
  kname.reserve(64);
  std::string type_string = get_type_string(x.dtype());
  concatenate(
      kname,
      mode + "_gather_qvm_",
      type_string,
      "_gs_",
      group_size,
      "_b_",
      bits);
  auto kernel = get_quantized_kernel_wrapped(
      d, kname, "gather_qvm", mode, type_string, group_size, bits);

  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  int c = 0;
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases) {
    compute_encoder.set_input_array(*biases, c++);
  }
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_input_array(lhs_indices, c++);
  compute_encoder.set_input_array(rhs_indices, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(K, c++);
  compute_encoder.set_bytes(N, c++);
  // NOTE(review): the trailing `c++` is redundant. c is overwritten by the
  // return value, and the side effect is sequenced before the assignment in
  // C++17, so the result equals passing plain `c` (as gather_qmv does).
  c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c++);
  add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// NAX-kernel variant of gather_qmm_rhs. gather_qmm_rhs forwards here when
// metal::is_nax_available(), transpose is true, and either TF32 is enabled or
// x is not float32. It uses bm = bn = bk = 64 tiles and a wm x wn = 2 x 2
// simdgroup layout. Alignment of M/N/K to the tile sizes is passed as Metal
// function constants 200/201/202.
//
// Buffer layout: x, w, scales, [biases], indices, out, M, N, K.
void gather_qmm_rhs_nax(
    const array& x_,
    const array& w_,
    const array& scales_,
    const std::optional& biases_,
    const array& indices_,
    array& out,
    bool transpose,
    int group_size,
    int bits,
    int M,
    int N,
    int K,
    metal::Device& d,
    const Stream& s,
    const std::string mode) {
  // Start by normalizing the indices
  array indices = ensure_row_contiguous(indices_, d, s);

  // Broadcast x with indices. If we are here that means lhs_indices were not
  // provided so the lhs_indices are implied to be the shape of x broadcasted
  // with rhs_indices. We need only broadcast x and copy it as if applying the
  // lhs_indices.
  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
    // Already one matrix per index: no broadcast needed.
    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
      return ensure_row_contiguous(x, d, s);
    }

    // Otherwise build a view of shape indices.shape() + x.shape()[-2:] and
    // materialize it contiguously.
    auto x_shape = indices.shape();
    x_shape.push_back(x.shape(-2));
    x_shape.push_back(x.shape(-1));
    array new_x(std::move(x_shape), x.dtype(), nullptr, {});
    broadcast(x, new_x);
    return ensure_row_contiguous(new_x, d, s);
  };

  // Normalize the input arrays
  array x = broadcast_with_indices(x_);
  array w = ensure_row_contiguous(w_, d, s);
  array scales = ensure_row_contiguous(scales_, d, s);

  // TODO: Tune the block sizes
  int bm = 64, bn = 64, bk = 64;
  int wm = 2, wn = 2;

  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  // Make the kernel name
  std::string kname;
  kname.reserve(64);
  std::string type_string = get_type_string(x.dtype());
  concatenate(
      kname,
      mode +
          (transpose ? "_gather_qmm_rhs_nax_nt_" : "_gather_qmm_rhs_nax_nn_"),
      type_string,
      "_gs_",
      group_size,
      "_b_",
      bits,
      "_bm_",
      bm,
      "_bn_",
      bn,
      "_bk_",
      bk,
      "_wm_",
      wm,
      "_wn_",
      wn);

  metal::MTLFCList func_consts = {
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
      {&align_K, MTL::DataType::DataTypeBool, 202},
  };

  // And the kernel hash that includes the function constants, so each
  // alignment specialization gets its own cache entry.
  std::string hash_name;
  hash_name.reserve(128);
  concatenate(
      hash_name,
      kname,
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n',
      "_align_K_",
      align_K ? 't' : 'n');

  // Get and set the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_gather_qmm_nax_kernel(
      d,
      kname,
      hash_name,
      func_consts,
      x,
      group_size,
      bits,
      mode,
      bm,
      bn,
      bk,
      wm,
      wn,
      transpose);
  compute_encoder.set_compute_pipeline_state(kernel);

  // One simdgroup (32 threads) per (wn, wm) slot; one threadgroup per
  // bm x bn output tile.
  MTL::Size group_dims(32, wn, wm);
  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);

  int c = 0;
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases_) {
    array biases = ensure_row_contiguous(*biases_, d, s);
    compute_encoder.set_input_array(biases, c++);
  }
  compute_encoder.set_input_array(indices, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(M, c++);
  compute_encoder.set_bytes(N, c++);
  compute_encoder.set_bytes(K, c++);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Batched gathered quantized matmul used when only rhs indices are needed.
// GatherQMM::eval_gpu calls this when M == 1, right_sorted_ is set, and there
// are enough rows per weight matrix. x is broadcast to one row block per index
// and processed as a single tiled matmul.
//
// When NAX is available, w is transposed, and TF32 is allowed (or x is not
// float32), the call is forwarded to gather_qmm_rhs_nax. Otherwise this uses
// bm = 16, bn = 32, bk = 32 tiles with wm x wn = 1 x 2 simdgroups.
//
// Buffer layout: x, w, scales, [biases], indices, out, M, N, K.
void gather_qmm_rhs(
    const array& x_,
    const array& w_,
    const array& scales_,
    const std::optional& biases_,
    const array& indices_,
    array& out,
    bool transpose,
    int group_size,
    int bits,
    int M,
    int N,
    int K,
    metal::Device& d,
    const Stream& s,
    const std::string mode) {
  if (metal::is_nax_available() && transpose &&
      (env::enable_tf32() || x_.dtype() != float32)) {
    return gather_qmm_rhs_nax(
        /* const array& x_ = */ x_,
        /* const array& w_ = */ w_,
        /* const array& scales_ = */ scales_,
        /* const std::optional& biases_ = */ biases_,
        /* const array& indices_ = */ indices_,
        /* array& out = */ out,
        /* bool transpose = */ transpose,
        /* int group_size = */ group_size,
        /* int bits = */ bits,
        /* int M = */ M,
        /* int N = */ N,
        /* int K = */ K,
        /* metal::Device& d = */ d,
        /* const Stream& s = */ s,
        /* const std::string mode = */ mode);
  }

  // Start by normalizing the indices
  array indices = ensure_row_contiguous(indices_, d, s);

  // Broadcast x with indices. If we are here that means lhs_indices were not
  // provided so the lhs_indices are implied to be the shape of x broadcasted
  // with rhs_indices. We need only broadcast x and copy it as if applying the
  // lhs_indices.
  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
    // Already one matrix per index: no broadcast needed.
    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
      return ensure_row_contiguous(x, d, s);
    }

    auto x_shape = indices.shape();
    x_shape.push_back(x.shape(-2));
    x_shape.push_back(x.shape(-1));
    array new_x(std::move(x_shape), x.dtype(), nullptr, {});
    broadcast(x, new_x);
    return ensure_row_contiguous(new_x, d, s);
  };

  // Normalize the input arrays
  array x = broadcast_with_indices(x_);
  array w = ensure_row_contiguous(w_, d, s);
  array scales = ensure_row_contiguous(scales_, d, s);

  // TODO: Tune the block sizes
  int bm = 16, bn = 32, bk = 32;
  int wm = 1, wn = 2;

  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  // Make the kernel name
  std::string kname;
  kname.reserve(64);
  std::string type_string = get_type_string(x.dtype());
  concatenate(
      kname,
      mode + (transpose ? "_gather_qmm_rhs_nt_" : "_gather_qmm_rhs_nn_"),
      type_string,
      "_gs_",
      group_size,
      "_b_",
      bits,
      "_bm_",
      bm,
      "_bn_",
      bn,
      "_bk_",
      bk,
      "_wm_",
      wm,
      "_wn_",
      wn);

  metal::MTLFCList func_consts = {
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
      {&align_K, MTL::DataType::DataTypeBool, 202},
  };

  // And the kernel hash that includes the function constants
  std::string hash_name;
  hash_name.reserve(128);
  concatenate(
      hash_name,
      kname,
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n',
      "_align_K_",
      align_K ? 't' : 'n');

  // Get and set the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_gather_qmm_kernel(
      d,
      kname,
      hash_name,
      func_consts,
      x,
      group_size,
      bits,
      mode,
      bm,
      bn,
      bk,
      wm,
      wn,
      transpose);
  compute_encoder.set_compute_pipeline_state(kernel);

  MTL::Size group_dims(32, wn, wm);
  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);

  int c = 0;
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases_) {
    array biases = ensure_row_contiguous(*biases_, d, s);
    compute_encoder.set_input_array(biases, c++);
  }
  compute_encoder.set_input_array(indices, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(M, c++);
  compute_encoder.set_bytes(N, c++);
  compute_encoder.set_bytes(K, c++);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Choose between the quad-specialized qmv kernel (K of 64 or 128 with a
// power-of-two bit width) and the general qmv kernel.
void dispatch_qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional& biases,
    array& out,
    int group_size,
    int bits,
    int M,
    int N,
    int K,
    metal::Device& d,
    const Stream& s,
    const std::string& mode) {
  // It is a qmv with a small inner dimension so route to qmv_quad kernel
  if ((K == 128 || K == 64) && is_power_of_2(bits)) {
    qmv_quad(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode);
    return;
  }

  // Run of the mill qmv
  qmv(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode);
}

// GPU evaluation of a quantized matmul. inputs are x, w, scales and, when
// there are 4 inputs, biases. Routing:
//   M >= vector_limit  -> qmm
//   transpose_         -> dispatch_qmv
//   K < 1024           -> qvm
//   otherwise          -> qvm_split_k
void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  out.set_data(allocator::malloc(out.nbytes()));

  // Make sure the last two dims of x and w, s, b are contiguous. This should
  // be relaxed for x.
  array x = ensure_row_contiguous_matrix(inputs[0], d, s);
  array w = ensure_row_contiguous_matrix(inputs[1], d, s);
  array scales = ensure_row_contiguous_matrix(inputs[2], d, s);
  std::optional biases = std::nullopt;
  if (inputs.size() == 4) {
    biases = ensure_row_contiguous_matrix(inputs[3], d, s);
  }

  // Extract the matmul shapes. With a single 2D w and a row-contiguous x, all
  // leading dims of x are folded into M.
  bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
  int K = x.shape(-1);
  int M = non_batched ? x.size() / K : x.shape(-2);
  int N = out.shape(-1);

  int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
  auto mode = quantization_mode_to_string(mode_);

  // It is a matrix matrix product.
  if (M >= vector_limit) {
    qmm(x,
        w,
        scales,
        biases,
        out,
        transpose_,
        group_size_,
        bits_,
        M,
        N,
        K,
        d,
        s,
        mode);
    return;
  }

  // Run of the mill qmv
  if (transpose_) {
    dispatch_qmv(
        x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
    return;
  }

  // Run of the mill qvm
  if (K < 1024) {
    qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
    return;
  }

  // Qvm with large dimension so route to a split K kernel for more parallelism
  qvm_split_k(
      x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
  return;
}

// GPU evaluation of a gathered quantized matmul. inputs are x, w, scales,
// [biases], lhs_indices, rhs_indices. biases are present when there are
// 6 inputs, and the indices are always the last two.
void GatherQMM::eval_gpu(const std::vector& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  out.set_data(allocator::malloc(out.nbytes()));

  array x = ensure_row_contiguous_matrix(inputs[0], d, s);
  array w = ensure_row_contiguous_matrix(inputs[1], d, s);
  array scales = ensure_row_contiguous_matrix(inputs[2], d, s);
  std::optional biases = std::nullopt;
  if (inputs.size() == 6) {
    biases = ensure_row_contiguous_matrix(inputs[3], d, s);
  }
  const array& lhs_indices = inputs[inputs.size() - 2];
  const array& rhs_indices = inputs[inputs.size() - 1];

  int K = x.shape(-1);
  int M = x.shape(-2);
  int N = out.shape(-1);
  // B: number of output matrices. E: number of stacked weight matrices in w
  // (presumably the number of experts; confirm with callers).
  int B = out.size() / M / N;
  int E = w.size() / w.shape(-1) / w.shape(-2);
  int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
  auto mode = quantization_mode_to_string(mode_);

  // We are walking x in order and w is also in order so we can batch up the
  // matmuls and reuse reading x and w.
  //
  // TODO: Tune 16 and 4 here a bit better.
  if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 4) {
    gather_qmm_rhs(
        x,
        w,
        scales,
        biases,
        rhs_indices,
        out,
        transpose_,
        group_size_,
        bits_,
        x.size() / K,
        N,
        K,
        d,
        s,
        mode);
    return;
  }

  // It is a matrix matrix product
  if (M >= vector_limit) {
    gather_qmm(
        x,
        w,
        scales,
        biases,
        lhs_indices,
        rhs_indices,
        out,
        transpose_,
        group_size_,
        bits_,
        M,
        N,
        K,
        d,
        s,
        mode);
    return;
  }

  if (transpose_) {
    gather_qmv(
        x,
        w,
        scales,
        biases,
        lhs_indices,
        rhs_indices,
        out,
        group_size_,
        bits_,
        M,
        N,
        K,
        d,
        s,
        mode);
    return;
  }

  gather_qvm(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      out,
      group_size_,
      bits_,
      M,
      N,
      K,
      d,
      s,
      mode);
}

// Encode the "quantize_dequantize" kernel: `in` (made row contiguous) is read
// at buffer 0 and the result written to `out` at buffer 1. The caller must
// already have allocated `out`. Each thread handles max(group_size / 32, 1)
// elements.
void quantize_dequantize(
    const array& in,
    array& out,
    std::string mode,
    int group_size,
    int bits,
    metal::Device& d,
    const Stream& s) {
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto w = ensure_row_contiguous(in, d, s);
  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_output_array(out, 1);

  auto type_string = get_type_string(in.dtype());
  std::string kname;
  concatenate(
      kname,
      mode + "_quantize_dequantize_",
      type_string,
      "_gs_",
      group_size,
      "_b_",
      bits);
  auto kernel = get_quantized_kernel_wrapped(
      d, kname, "quantize_dequantize", mode, type_string, group_size, bits);

  compute_encoder.set_compute_pipeline_state(kernel);

  // NOTE(review): uint8_per_uint32 and packs_per_int are not used below in
  // this function; they look copied from fast::Quantize::eval_gpu.
  constexpr int uint8_per_uint32 = 4;
  constexpr int simd_size = 32;
  int packs_per_int = (bits == 3 || bits == 5) ? 8
      : bits == 6                               ? 4
                                                : 8 / bits;
  int per_thread = std::max(group_size / simd_size, 1);
  size_t nthreads = w.size() / per_thread;

  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > nthreads) {
    thread_group_size = nthreads;
  }
  auto group_dims = MTL::Size(thread_group_size, 1, 1);
  // Fall back to a 2D grid when the thread count does not fit in 32 bits.
  bool use_2d = nthreads > UINT_MAX;
  auto grid_shape = w.shape();
  grid_shape.back() /= per_thread;
  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides())
                               : MTL::Size(nthreads, 1, 1);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

// GPU evaluation of a matmul where x is quantized on the fly. Only the case
// where w is quantized (uint32) and x.shape(-2) == 1 is implemented. x is
// round-tripped through quantize_dequantize into xhat, then multiplied via
// dispatch_qmv without biases. Any other case throws std::runtime_error.
void QQMatmul::eval_gpu(const std::vector& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto mode = quantization_mode_to_string(mode_);

  bool w_quantized = (inputs[1].dtype() == uint32);
  if (w_quantized && inputs[0].shape(-2) == 1) {
    out.set_data(allocator::malloc(out.nbytes()));

    bool donate_x = inputs[0].is_donatable();
    array x = ensure_row_contiguous(inputs[0], d, s);
    // If x is a copy it should be donatable
    donate_x |= x.is_donatable();
    // When donating, xhat aliases x and quantize_dequantize runs in place.
    // NOTE(review): otherwise xhat is a fresh allocation that is not passed to
    // d.add_temporary in this function. Confirm its buffer stays alive until
    // the command buffer completes.
    auto xhat = donate_x
        ? x
        : array(allocator::malloc(x.nbytes()), x.shape(), x.dtype());
    quantize_dequantize(x, xhat, mode, group_size_, bits_, d, s);

    // Make sure the last two dims of w and s are contiguous
    array w = ensure_row_contiguous_matrix(inputs[1], d, s);
    array scales = ensure_row_contiguous_matrix(inputs[2], d, s);

    bool non_batched = w.ndim() == 2;
    int K = x.shape(-1);
    int M = non_batched ? x.size() / K : x.shape(-2);
    int N = out.shape(-1);

    dispatch_qmv(
        xhat,
        w,
        scales,
        std::nullopt,
        out,
        group_size_,
        bits_,
        M,
        N,
        K,
        d,
        s,
        mode);
    return;
  } else {
    throw std::runtime_error("[QQMatmul] NYI for the general case");
  }
}

// GPU evaluation of fast quantize / dequantize.
//   dequantize_ == true:  inputs w, scales, [biases (Affine mode)] -> out.
//                         Buffers: w 0, scales 1, biases 2, out 3.
//   dequantize_ == false: input w -> outputs out, scales, [biases (Affine)].
//                         Buffers: w 0, out 1, scales 2, biases 3.
void fast::Quantize::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  auto& w_pre = inputs[0];
  auto& out = outputs[0];
  out.set_data(allocator::malloc(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& compute_encoder = d.get_command_encoder(s.index);

  auto w = ensure_row_contiguous(w_pre, d, s);
  if (dequantize_) {
    auto scales = ensure_row_contiguous(inputs[1], d, s);
    if (mode_ == QuantizationMode::Affine) {
      auto biases = ensure_row_contiguous(inputs[2], d, s);
      compute_encoder.set_input_array(biases, 2);
    }
    compute_encoder.set_input_array(w, 0);
    compute_encoder.set_input_array(scales, 1);
    compute_encoder.set_output_array(out, 3);
  } else {
    auto& scales = outputs[1];
    scales.set_data(allocator::malloc(scales.nbytes()));
    if (mode_ == QuantizationMode::Affine) {
      auto& biases = outputs[2];
      biases.set_data(allocator::malloc(biases.nbytes()));
      compute_encoder.set_output_array(biases, 3);
    }
    compute_encoder.set_input_array(w, 0);
    compute_encoder.set_output_array(out, 1);
    compute_encoder.set_output_array(scales, 2);
  }

  // The kernel is specialized on the floating-point side of the conversion.
  auto type_string = dequantize_ ? get_type_string(out.dtype())
                                 : get_type_string(w_pre.dtype());
  auto mode = quantization_mode_to_string(mode_);
  std::string kname;
  concatenate(
      kname,
      mode + (dequantize_ ? "_dequantize" : "_quantize"),
      "_",
      type_string,
      "_gs_",
      group_size_,
      "_b_",
      bits_);
  auto kernel = get_quantized_kernel_wrapped(
      d,
      kname,
      dequantize_ ? "dequantize" : "quantize",
      mode,
      type_string,
      group_size_,
      bits_);

  compute_encoder.set_compute_pipeline_state(kernel);

  // Treat uint32 as uint8 in kernel
  constexpr int uint8_per_uint32 = 4;
  constexpr int simd_size = 32;
  // Packing factor: 3/5-bit -> 8, 6-bit -> 4, otherwise 8 / bits.
  int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
      : bits_ == 6                                ? 4
                                                  : 8 / bits_;
  int per_thread =
      dequantize_ ? packs_per_int : std::max(group_size_ / simd_size, 1);
  size_t nthreads =
      dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;

  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > nthreads) {
    thread_group_size = nthreads;
  }
  auto group_dims = MTL::Size(thread_group_size, 1, 1);
  bool use_2d = nthreads > UINT_MAX;
  auto grid_shape = w.shape();
  if (dequantize_) {
    grid_shape.back() *= uint8_per_uint32;
  } else {
    grid_shape.back() /= per_thread;
  }
  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides())
                               : MTL::Size(nthreads, 1, 1);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

// FP8 conversion is implemented as a unary GPU op named by name().
void fast::ConvertFP8::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  // NOTE(review): `in` is unused; unary_op_gpu receives the whole inputs.
  auto& in = inputs[0];
  auto& out = outputs[0];
  unary_op_gpu(inputs, out, name(), stream());
}

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/reduce.cpp
================================================
// Copyright © 2023-2024 Apple Inc.

// NOTE(review): the targets of the next two #include lines are missing in
// this extract.
#include
#include

#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {

// Kernel arguments for row reductions. encode() writes them to buffers 2..9.
struct RowReduceArgs {
  // Input shape and strides not including the reduction axes
  Shape shape;
  Strides strides;
  int ndim;

  // Input shape and strides for the reduction axes
  Shape reduce_shape;
  Strides reduce_strides;
  int reduce_ndim;

  // The number of rows we are reducing. Namely prod(reduce_shape).
  size_t non_row_reductions;

  // The size of the row.
  size_t row_size;

  RowReduceArgs(
      const array& in,
      const ReductionPlan& plan,
      const std::vector& axes) {
    // The last plan axis is the row itself; the rest are non-row reductions.
    row_size = plan.shape.back();

    reduce_shape = plan.shape;
    reduce_strides = plan.strides;
    reduce_shape.pop_back();
    reduce_strides.pop_back();
    reduce_ndim = reduce_shape.size();

    non_row_reductions = 1;
    for (auto s : reduce_shape) {
      non_row_reductions *= s;
    }

    std::tie(shape, strides) = shapes_without_reduction_axes(in, axes);
    std::tie(shape, strides) = collapse_contiguous_dims(shape, strides);
    ndim = shape.size();
  }

  void encode(CommandEncoder& compute_encoder) {
    // Push 0s to avoid encoding empty vectors.
    if (reduce_ndim == 0) {
      reduce_shape.push_back(0);
      reduce_strides.push_back(0);
    }
    if (ndim == 0) {
      shape.push_back(0);
      strides.push_back(0);
    }

    compute_encoder.set_bytes(row_size, 2);
    compute_encoder.set_bytes(non_row_reductions, 3);
    compute_encoder.set_vector_bytes(shape, 4);
    compute_encoder.set_vector_bytes(strides, 5);
    compute_encoder.set_bytes(ndim, 6);
    compute_encoder.set_vector_bytes(reduce_shape, 7);
    compute_encoder.set_vector_bytes(reduce_strides, 8);
    compute_encoder.set_bytes(reduce_ndim, 9);

    // Undo the padding so the struct is unchanged after encoding.
    if (reduce_ndim == 0) {
      reduce_shape.pop_back();
      reduce_strides.pop_back();
    }
    if (ndim == 0) {
      shape.pop_back();
      strides.pop_back();
    }
  }
};

// Kernel arguments for column (strided) reductions. encode() writes them to
// buffers 2..10.
struct ColReduceArgs {
  // Input shape and strides not including the reduction axes
  Shape shape;
  Strides strides;
  int ndim;

  // Input shape and strides for the reduction axes
  Shape reduce_shape;
  Strides reduce_strides;
  int reduce_ndim;

  // The number of column reductions we are doing. Namely prod(reduce_shape).
  size_t non_col_reductions;

  // The size of the contiguous column reduction.
  size_t reduction_size;
  int64_t reduction_stride;

  ColReduceArgs(
      const array& in,
      const ReductionPlan& plan,
      const std::vector& axes) {
    reduction_size = plan.shape.back();
    reduction_stride = plan.strides.back();

    reduce_shape = plan.shape;
    reduce_strides = plan.strides;
    reduce_shape.pop_back();
    reduce_strides.pop_back();
    reduce_ndim = reduce_shape.size();

    non_col_reductions = 1;
    for (auto s : reduce_shape) {
      non_col_reductions *= s;
    }

    // We'll use a stride_back variable because strides.back() could be 0 but
    // yet we may have removed the appropriate amount of elements. It is safe
    // to compute the stride by multiplying shapes (while < reduction_stride)
    // because it is a contiguous section.
    int64_t stride_back = 1;
    std::tie(shape, strides) = shapes_without_reduction_axes(in, axes);
    while (!shape.empty() && stride_back < reduction_stride) {
      stride_back *= shape.back();
      shape.pop_back();
      strides.pop_back();
    }
    std::tie(shape, strides) = collapse_contiguous_dims(shape, strides);
    ndim = shape.size();
  }

  /**
   * Create the col reduce arguments for reducing the 1st axis of the row
   * contiguous intermediate array.
   */
  // NOTE(review): single-argument constructor is not `explicit`.
  ColReduceArgs(const array& intermediate) {
    assert(intermediate.flags().row_contiguous);

    reduction_size = intermediate.shape(0);
    reduction_stride = intermediate.size() / reduction_size;
    non_col_reductions = 1;
    reduce_ndim = 0;
    ndim = 0;
  }

  void encode(CommandEncoder& compute_encoder) {
    // Push 0s to avoid encoding empty vectors.
    if (reduce_ndim == 0) {
      reduce_shape.push_back(0);
      reduce_strides.push_back(0);
    }
    if (ndim == 0) {
      shape.push_back(0);
      strides.push_back(0);
    }

    compute_encoder.set_bytes(reduction_size, 2);
    compute_encoder.set_bytes(reduction_stride, 3);
    compute_encoder.set_vector_bytes(shape, 4);
    compute_encoder.set_vector_bytes(strides, 5);
    compute_encoder.set_bytes(ndim, 6);
    compute_encoder.set_vector_bytes(reduce_shape, 7);
    compute_encoder.set_vector_bytes(reduce_strides, 8);
    compute_encoder.set_bytes(reduce_ndim, 9);
    compute_encoder.set_bytes(non_col_reductions, 10);

    // Undo the padding so the struct is unchanged after encoding.
    if (reduce_ndim == 0) {
      reduce_shape.pop_back();
      reduce_strides.pop_back();
    }
    if (ndim == 0) {
      shape.pop_back();
      strides.pop_back();
    }
  }
};

} // namespace

// ceil(n / m), or 0 when m == 0.
inline auto safe_div(size_t n, size_t m) {
  return m == 0 ? 0 : (n + m - 1) / m;
}

// n rounded up to a multiple of m (0 when m == 0).
inline auto safe_divup(size_t n, size_t m) {
  return safe_div(n, m) * m;
}

inline bool is_64b_int(Dtype dtype) {
  return dtype == int64 || dtype == uint64;
}

inline bool is_64b_dtype(Dtype dtype) {
  return dtype == int64 || dtype == uint64 || dtype == complex64;
}

// Bucket the number of reduction dims into the compiled variants 1, 2 or 5.
inline int get_kernel_reduce_ndim(int reduce_ndim) {
  if (reduce_ndim <= 1) {
    return 1;
  } else if (reduce_ndim == 2) {
    return 2;
  } else {
    return 5;
  }
}

inline int threadgroup_size_from_row_size(int row_size) {
  // 1 simdgroup per row smallish rows
  if (row_size <= 512) {
    return 32;
  }

  // 128 threads (4 simdgroups of 32) per row for medium rows
  if (row_size <= 1024) {
    return 128;
  }

  // Otherwise ceil(row_size / REDUCE_N_READS) threads, rounded up to a
  // multiple of 32 and capped at 1024 (up to 32 simdgroups).
  int thread_group_size;
  thread_group_size = (row_size + REDUCE_N_READS - 1) / REDUCE_N_READS;
  thread_group_size = ((thread_group_size + 31) / 32) * 32;
  thread_group_size = std::min(1024, thread_group_size);
  return thread_group_size;
}

// 2D grid over the output after dropping trailing output dims whose stride is
// smaller than the column reduction stride.
inline auto output_grid_for_col_reduce(
    const array& out,
    const ColReduceArgs& args) {
  auto out_shape = out.shape();
  auto out_strides = out.strides();
  while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
    out_shape.pop_back();
    out_strides.pop_back();
  }
  return get_2d_grid_dims(out_shape, out_strides);
}

// Returns {dtype the kernel reads the input as, dtype the kernel writes}.
// sum/prod widen small integers and bool to 32-bit results. and/or read the
// input as a same-sized type and produce bool. Everything else is unchanged.
std::pair remap_reduce_types(
    const array& in,
    const std::string& op_name) {
  if (op_name == "sum" || op_name == "prod") {
    if (issubdtype(in.dtype(), integer)) {
      switch (in.dtype()) {
        case uint8:
          return {uint8, uint32};
        case uint16:
          return {uint16, uint32};
        case uint32:
          return {uint32, uint32};
        case uint64:
          return {uint64, uint64};
        case int8:
          return {int8, int32};
        case int16:
          return {int16, int32};
        case int32:
          return {int32, int32};
        case int64:
          return {int64, int64};
        default:
          throw std::runtime_error("Unsupported integer type");
      }
    }
    if (in.dtype() == bool_) {
      return {int8, int32};
    }
    return {in.dtype(), in.dtype()};
  } else if (op_name == "and" || op_name == "or") {
    if (in.dtype().size() == 1) {
      return {bool_, bool_};
    } else if (in.dtype().size() == 2) {
      return {int16, bool_};
    } else if (in.dtype().size() == 4) {
      return {int32, bool_};
    } else {
      return {int64, bool_};
    }
  }
  return {in.dtype(), in.dtype()};
}

// Encode the "init_reduce" kernel over every element of out, one thread per
// element (presumably writing the reduction's initial value; confirm against
// the kernel).
void init_reduce(
    array& out,
    const std::string& op_name,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  auto [_, out_type] = remap_reduce_types(out, op_name);
  const std::string func_name = "init_reduce";
  std::string kname = func_name;
  concatenate(kname, "_", op_name, type_to_name(out_type));
  auto kernel = get_reduce_init_kernel(d, kname, func_name, op_name, out_type);
  size_t nthreads = out.size();
  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > nthreads) {
    thread_group_size = nthreads;
  }
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
  compute_encoder.set_compute_pipeline_state(kernel);
  compute_encoder.set_output_array(out, 0);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

// Reduce all of `in` into `out`. Inputs of at most REDUCE_N_READS * 1024
// elements use one threadgroup. Larger inputs use two passes through an
// n_rows-element intermediate registered as a temporary.
void all_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  // Set the kernel
  auto [in_type, out_type] = remap_reduce_types(in, op_name);
  const std::string func_name = "all_reduce";
  std::string kname = func_name;
  concatenate(kname, "_", op_name, type_to_name(in_type));
  auto kernel = get_reduce_kernel(
      d, kname, func_name, op_name, in_type, out_type, "int64_t");
  compute_encoder.set_compute_pipeline_state(kernel);

  size_t in_size = in.size();

  // Small array so dispatch a single threadgroup
  if (in_size <= REDUCE_N_READS * 1024) {
    int threadgroup_size = (in_size + REDUCE_N_READS - 1) / REDUCE_N_READS;
    threadgroup_size = ((threadgroup_size + 31) / 32) * 32;
    MTL::Size grid_dims(threadgroup_size, 1, 1);

    compute_encoder.set_input_array(in, 0);
    compute_encoder.set_output_array(out, 1);
    compute_encoder.set_bytes(in_size, 2);
    compute_encoder.set_bytes(in_size, 3);
    // Grid and threadgroup are the same size: exactly one threadgroup.
    compute_encoder.dispatch_threads(grid_dims, grid_dims);
  }

  // We need multiple threadgroups so we'll do it in 2 passes.
  else {
    int n_rows, threadgroup_2nd_pass;
    // Less than 2**26 bytes
    if (in.nbytes() <= (1 << 26)) {
      n_rows = 32 * REDUCE_N_READS;
      threadgroup_2nd_pass = 32;
    }

    // Really large matrix so parallelize as much as possible
    else {
      n_rows = 1024 * REDUCE_N_READS;
      threadgroup_2nd_pass = 1024;
    }

    // Allocate an intermediate tensor to hold results if needed
    array intermediate({n_rows}, out_type, nullptr, {});
    intermediate.set_data(allocator::malloc(intermediate.nbytes()));
    d.add_temporary(intermediate, s.index);

    // 1st pass: one threadgroup per row of ceil(in_size / n_rows) elements.
    size_t row_size = (in_size + n_rows - 1) / n_rows;
    int threadgroup_size =
        std::min((row_size + REDUCE_N_READS - 1) / REDUCE_N_READS, 1024ul);
    threadgroup_size = ((threadgroup_size + 31) / 32) * 32;
    MTL::Size grid_dims(threadgroup_size, n_rows, 1);
    MTL::Size group_dims(threadgroup_size, 1, 1);
    compute_encoder.set_input_array(in, 0);
    compute_encoder.set_output_array(intermediate, 1);
    compute_encoder.set_bytes(in_size, 2);
    compute_encoder.set_bytes(row_size, 3);
    compute_encoder.dispatch_threads(grid_dims, group_dims);

    // 2nd pass: reduce the intermediate with a single threadgroup.
    std::string kname_2nd_pass = func_name;
    concatenate(kname_2nd_pass, "_", op_name, type_to_name(intermediate));
    auto kernel_2nd_pass = get_reduce_kernel(
        d, kname_2nd_pass, func_name, op_name, out_type, out_type, "int64_t");
    compute_encoder.set_compute_pipeline_state(kernel_2nd_pass);
    size_t intermediate_size = n_rows;
    grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
    group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
    compute_encoder.set_input_array(intermediate, 0);
    compute_encoder.set_output_array(out, 1);
    compute_encoder.set_bytes(intermediate_size, 2);
    compute_encoder.set_bytes(intermediate_size, 3);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  }
}

// Row reduction for short rows. Uses one thread per output element when there
// are few non-row reductions; otherwise one 32-thread group per output.
void row_reduce_small(
    const array& in,
    array& out,
    const std::string& op_name,
    RowReduceArgs& args,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  // Set the kernel
  int n = get_kernel_reduce_ndim(args.reduce_ndim);
  auto [in_type, out_type] = remap_reduce_types(in, op_name);
  const std::string func_name = "row_reduce_small";
  std::string kname = func_name;
  // Inputs above INT32_MAX elements need a 64-bit index type in the kernel.
  bool large = in.size() > INT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(
      kname,
      "_",
      std::to_string(n),
      "_reduce_",
      op_name,
      type_to_name(in_type));
  auto kernel = get_reduce_kernel(
      d,
      kname,
      func_name,
      op_name,
      in_type,
      out_type,
      large ? "size_t" : "int",
      n);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Figure out the grid dims
  MTL::Size grid_dims;
  MTL::Size group_dims;
  if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
      args.non_row_reductions <= 8) {
    grid_dims = get_2d_grid_dims(out.shape(), out.strides());
    group_dims =
        MTL::Size((grid_dims.width < 1024) ? grid_dims.width : 1024, 1, 1);
  } else {
    auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());
    grid_dims = MTL::Size(32, out_grid_size.width, out_grid_size.height);
    group_dims = MTL::Size(32, 1, 1);
  }

  // Launch
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(out, 1);
  args.encode(compute_encoder);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

// Row reduction for a contiguous input with no non-row reduction axes. Each
// threadgroup produces REDUCE_N_WRITES outputs; threadgroups are capped at 512
// threads for 8-byte input types.
void row_reduce_simple(
    const array& in,
    array& out,
    const std::string& op_name,
    RowReduceArgs& args,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  // Set the kernel
  auto [in_type, out_type] = remap_reduce_types(in, op_name);
  const std::string func_name = "row_reduce_simple";
  std::string kname = func_name;
  concatenate(kname, "_", op_name, type_to_name(in_type));
  auto kernel = get_reduce_kernel(
      d, kname, func_name, op_name, in_type, out_type, "size_t");
  compute_encoder.set_compute_pipeline_state(kernel);

  // Figure out the grid dims
  size_t row_size = args.row_size;
  size_t out_size = out.size();
  auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());
  out_grid_size.width =
      (out_grid_size.width + REDUCE_N_WRITES - 1) / REDUCE_N_WRITES;
  int threadgroup_size = threadgroup_size_from_row_size(row_size);
  if (in.itemsize() == 8) {
    threadgroup_size = std::min(threadgroup_size, 512);
  }
  MTL::Size grid_dims(
      threadgroup_size, out_grid_size.width, out_grid_size.height);
  MTL::Size group_dims(threadgroup_size, 1, 1);

  // Launch
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(out, 1);
  compute_encoder.set_bytes(row_size, 2);
  compute_encoder.set_bytes(out_size, 3);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

// General row reduction: one threadgroup (sized from row_size) per output,
// looping over the non-row reduction axes.
void row_reduce_looped(
    const array& in,
    array& out,
    const std::string& op_name,
    RowReduceArgs& args,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  auto [in_type, out_type] = remap_reduce_types(in, op_name);

  // Set the kernel
  int n = get_kernel_reduce_ndim(args.reduce_ndim);
  const std::string func_name = "row_reduce_looped";
  std::string kname = func_name;
  bool large = in.size() > INT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(
      kname,
      "_",
      std::to_string(n),
      "_reduce_",
      op_name,
      type_to_name(in_type));
  auto kernel = get_reduce_kernel(
      d,
      kname,
      func_name,
      op_name,
      in_type,
      out_type,
      large ? "size_t" : "int",
      n);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Figure out the grid
  auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());
  int threadgroup_size = threadgroup_size_from_row_size(args.row_size);
  MTL::Size grid_dims(
      threadgroup_size, out_grid_size.width, out_grid_size.height);
  MTL::Size group_dims(threadgroup_size, 1, 1);

  // Launch
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(out, 1);
  args.encode(compute_encoder);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

// Pick a row-reduce strategy: small rows, simple contiguous, or looped.
void row_reduce_general_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector& axes,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  // Prepare the arguments for the kernel
  RowReduceArgs args(in, plan, axes);

  // Case 1: The row is small
  if (args.row_size <= 64) {
    return row_reduce_small(in, out, op_name, args, compute_encoder, d, s);
  }

  // Case 2: Contiguous reduce without non-row reductions
  if (plan.type == ContiguousReduce && args.reduce_ndim == 0 &&
      in.size() / args.row_size >= 32) {
    return row_reduce_simple(in, out, op_name, args, compute_encoder, d, s);
  }

  // Case 3: General row reduce including non-row reductions
  return row_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
}

// Column reduction with small threadgroups. The column axis is appended to
// args' reduce dims (mutating args). Threadgroups are at most 32 x 8, with
// each x-thread covering n_reads (4) columns.
void strided_reduce_small(
    const array& in,
    array& out,
    const std::string& op_name,
    ColReduceArgs& args,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  auto [in_type, out_type] = remap_reduce_types(in, op_name);

  // Figure out the grid dims
  MTL::Size grid_dims, group_dims;

  // Prepare the arguments for the kernel
  args.reduce_shape.push_back(args.reduction_size);
  args.reduce_strides.push_back(args.reduction_stride);
  args.reduce_ndim++;

  int n = get_kernel_reduce_ndim(args.reduce_ndim);
  const std::string func_name = "col_reduce_small";
  std::string kname = func_name;
  bool large = in.size() > INT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(
      kname,
      "_",
      std::to_string(n),
      "_reduce_",
      op_name,
      type_to_name(in_type));
  auto kernel = get_reduce_kernel(
      d,
      kname,
      func_name,
      op_name,
      in_type,
      out_type,
      large ? "size_t" : "int",
      n);
  compute_encoder.set_compute_pipeline_state(kernel);

  const int n_reads = 4;
  size_t reduction_stride_blocks =
      (args.reduction_stride + n_reads - 1) / n_reads;
  size_t total = args.reduction_size * args.non_col_reductions;
  size_t threadgroup_x = std::min(reduction_stride_blocks, 32ul);
  size_t threadgroup_y = std::min(
      8ul,
      std::min(kernel->maxTotalThreadsPerThreadgroup() / threadgroup_x, total));

  group_dims = MTL::Size(threadgroup_x, threadgroup_y, 1);
  grid_dims = output_grid_for_col_reduce(out, args);
  grid_dims = MTL::Size(
      (reduction_stride_blocks + threadgroup_x - 1) / threadgroup_x,
      grid_dims.width,
      grid_dims.height);

  // Launch
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(out, 1);
  args.encode(compute_encoder);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Column reduction for long columns, in two passes:
//   1. col_reduce_longcolumn writes partial results for outer_blocks (32, or
//      128 when the total reduction size is >= 32768) slices into an
//      intermediate of shape [outer_blocks, *out.shape()].
//   2. col_reduce_looped reduces the intermediate's first axis into out.
// Mutates args by appending the column axis to its reduce dims.
void strided_reduce_longcolumn(
    const array& in,
    array& out,
    const std::string& op_name,
    ColReduceArgs& args,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  auto [in_type, out_type] = remap_reduce_types(in, op_name);
  size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
  size_t outer_blocks = 32;
  if (total_reduction_size >= 32768) {
    outer_blocks = 128;
  }

  // Prepare the temporary accumulator
  Shape intermediate_shape;
  intermediate_shape.reserve(out.ndim() + 1);
  intermediate_shape.push_back(outer_blocks);
  intermediate_shape.insert(
      intermediate_shape.end(), out.shape().begin(), out.shape().end());
  array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
  d.add_temporary(intermediate, s.index);

  // Prepare the arguments for the kernel
  args.reduce_shape.push_back(args.reduction_size);
  args.reduce_strides.push_back(args.reduction_stride);
  args.reduce_ndim++;

  // Figure out the grid dims
  size_t out_size = out.size();
  size_t threadgroup_x = args.reduction_stride;
  size_t threadgroup_y =
      (args.non_col_reductions * args.reduction_size + outer_blocks - 1) /
      outer_blocks;
  threadgroup_y = std::min(32ul, threadgroup_y);

  auto out_grid_size = output_grid_for_col_reduce(out, args);
  MTL::Size grid_dims(out_grid_size.width, out_grid_size.height, outer_blocks);
  MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);

  // Set the kernel
  int n = get_kernel_reduce_ndim(args.reduce_ndim);
  std::string func_name = "col_reduce_longcolumn";
  std::string kname = func_name;
  bool large = in.size() > INT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(
      kname,
      "_",
      std::to_string(n),
      "_reduce_",
      op_name,
      type_to_name(in_type));
  auto kernel = get_reduce_kernel(
      d,
      kname,
      func_name,
      op_name,
      in_type,
      out_type,
      large ? "int64_t" : "int",
      n);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Launch
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(intermediate, 1);
  args.encode(compute_encoder);
  compute_encoder.set_bytes(out_size, 11);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

  // Make the 2nd pass arguments and grid_dims
  ColReduceArgs second_args(intermediate);
  second_args.reduce_shape.push_back(outer_blocks);
  second_args.reduce_strides.push_back(out.size());
  second_args.reduce_ndim++;
  int BN = 32;
  grid_dims = MTL::Size(256 * ((out.size() + BN - 1) / BN), 1, 1);
  group_dims = MTL::Size(256, 1, 1);

  // Set the 2nd kernel
  func_name = "col_reduce_looped";
  kname = func_name;
  large = intermediate.size() > INT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
  kernel = get_reduce_kernel(
      d,
      kname,
      func_name,
      op_name,
      intermediate.dtype(),
      out_type,
      large ? "int64_t" : "int",
      1,
      32,
      32);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(intermediate, 0);
  compute_encoder.set_output_array(out, 1);
  second_args.encode(compute_encoder);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

// General column reduction with 256-thread groups, each covering BN = 32
// columns of the reduction stride. The kernel is specialized on BM = 32 and
// BN = 32. Mutates args by appending the column axis to its reduce dims.
void strided_reduce_looped(
    const array& in,
    array& out,
    const std::string& op_name,
    ColReduceArgs& args,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  auto [in_type, out_type] = remap_reduce_types(in, op_name);

  // Prepare the arguments for the kernel
  args.reduce_shape.push_back(args.reduction_size);
  args.reduce_strides.push_back(args.reduction_stride);
  args.reduce_ndim++;

  // Figure out the grid dims
  auto out_grid_size = output_grid_for_col_reduce(out, args);
  int BN = 32;
  int BM = 1024 / BN;
  int threadgroup_size = 8 * 32;
  MTL::Size grid_dims(
      threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
      out_grid_size.width,
      out_grid_size.height);
  MTL::Size group_dims(threadgroup_size, 1, 1);

  // Set the kernel
  int n = get_kernel_reduce_ndim(args.reduce_ndim);
  std::string func_name = "col_reduce_looped";
  std::string kname = func_name;
  bool large = in.size() > INT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(
      kname,
      "_",
      std::to_string(n),
      "_",
      std::to_string(BM),
      "_",
      std::to_string(BN),
      "_reduce_",
      op_name,
      type_to_name(in_type));
  auto kernel = get_reduce_kernel(
      d,
      kname,
      func_name,
      op_name,
      in_type,
      out_type,
      large ? "int64_t" : "int",
      n,
      BM,
      BN);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Launch
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(out, 1);
  args.encode(compute_encoder);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void strided_reduce_2pass(
    const array& in,
    array& out,
    const std::string& op_name,
    ColReduceArgs& args,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  auto [in_type, out_type] = remap_reduce_types(in, op_name);

  // Prepare the temporary accumulator
  Shape intermediate_shape;
  intermediate_shape.reserve(out.ndim() + 1);
  intermediate_shape.push_back(32);
  intermediate_shape.insert(
      intermediate_shape.end(), out.shape().begin(), out.shape().end());
  array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
  d.add_temporary(intermediate, s.index);

  // Prepare the arguments for the kernel
  args.reduce_shape.push_back(args.reduction_size);
  args.reduce_strides.push_back(args.reduction_stride);
  args.reduce_ndim++;

  // Figure out the grid dims
  size_t out_size = out.size() / args.reduction_stride;
  auto out_grid_size = output_grid_for_col_reduce(out, args);
  int outer_blocks = 32;
  int BN = 32;
  int BM = 1024 / BN;
  int threadgroup_size = 8 * 32;
  MTL::Size grid_dims(
      threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
      out_grid_size.width * outer_blocks,
      out_grid_size.height);
  MTL::Size group_dims(threadgroup_size, 1, 1);

  // Set the kernel
  int n = get_kernel_reduce_ndim(args.reduce_ndim);
  std::string func_name
= "col_reduce_2pass"; std::string kname = func_name; bool large = in.size() > INT32_MAX; if (large) { kname += "_large"; } concatenate( kname, "_", std::to_string(n), "_", std::to_string(BM), "_", std::to_string(BN), "_reduce_", op_name, type_to_name(in_type)); auto kernel = get_reduce_kernel( d, kname, func_name, op_name, in_type, out_type, large ? "int64_t" : "int", n, BM, BN); compute_encoder.set_compute_pipeline_state(kernel); // Launch compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(intermediate, 1); args.encode(compute_encoder); compute_encoder.set_bytes(out_size, 11); compute_encoder.dispatch_threads(grid_dims, group_dims); // Make the 2nd pass arguments and grid_dims ColReduceArgs second_args(intermediate); second_args.reduce_shape.push_back(outer_blocks); second_args.reduce_strides.push_back(out.size()); second_args.reduce_ndim++; grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1); // Set the 2nd kernel func_name = "col_reduce_looped"; kname = func_name; large = intermediate.size() > INT32_MAX; if (large) { kname += "_large"; } concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate)); kernel = get_reduce_kernel( d, kname, func_name, op_name, intermediate.dtype(), out_type, large ? 
"int64_t" : "int", 1, 32, 32); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_output_array(out, 1); second_args.encode(compute_encoder); compute_encoder.dispatch_threads(grid_dims, group_dims); } void strided_reduce_general_dispatch( const array& in, array& out, const std::string& op_name, const ReductionPlan& plan, const std::vector& axes, CommandEncoder& compute_encoder, metal::Device& d, const Stream& s) { // Prepare the arguments for the kernel ColReduceArgs args(in, plan, axes); // Small column if (args.reduction_size * args.non_col_reductions < 32) { return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s); } // Long column but small row if (args.reduction_stride < 32 && args.reduction_size * args.non_col_reductions >= 1024) { return strided_reduce_longcolumn( in, out, op_name, args, compute_encoder, d, s); } if (args.reduction_size * args.non_col_reductions > 256 && out.size() / 32 < 1024) { return strided_reduce_2pass(in, out, op_name, args, compute_encoder, d, s); } return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s); } void Reduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); array in = inputs[0]; // Make sure no identity reductions trickle down here assert(!axes_.empty()); assert(out.size() != in.size()); // Continue with reduction operation // Minimum of 4 bytes since we use size 4 structs for all reduce // and metal will complain o/w size_t min_bytes = std::max(out.nbytes(), 4ul); out.set_data(allocator::malloc(min_bytes)); std::string op_name; switch (reduce_type_) { case Reduce::And: op_name = "and"; break; case Reduce::Or: op_name = "or"; break; case Reduce::Sum: op_name = "sum"; break; case Reduce::Prod: op_name = "prod"; break; case Reduce::Min: op_name = out.dtype() == bool_ ? "and" : "min"; break; case Reduce::Max: op_name = out.dtype() == bool_ ? 
"or" : "max"; break; } // Initialize output auto& s = stream(); auto& d = metal::device(s.device); auto& compute_encoder = d.get_command_encoder(s.index); // Reduce if (in.size() > 0) { ReductionPlan plan = get_reduction_plan(in, axes_); // If it is a general reduce then copy the input to a contiguous array and // recompute the plan. // // TODO: This can be avoided by making the output have the same strides as // input for the axes with stride smaller than the minimum reduction // stride. if (plan.type == GeneralReduce) { array in_copy = contiguous_copy_gpu(in, s); d.add_temporary(in_copy, s.index); in = in_copy; plan = get_reduction_plan(in, axes_); } // Reducing over everything and the data is all there no broadcasting or // slicing etc. if (plan.type == ContiguousAllReduce) { all_reduce_dispatch(in, out, op_name, compute_encoder, d, s); } // At least the last dimension is row contiguous and we are reducing over // the last dim. else if ( plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { row_reduce_general_dispatch( in, out, op_name, plan, axes_, compute_encoder, d, s); } // At least the last two dimensions are contiguous and we are doing a // strided reduce over these. else if ( plan.type == ContiguousStridedReduce || plan.type == GeneralStridedReduce) { strided_reduce_general_dispatch( in, out, op_name, plan, axes_, compute_encoder, d, s); } } // Nothing to reduce just initialize the output else { init_reduce(out, op_name, compute_encoder, d, s); } } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/reduce.h ================================================ // Copyright @ 2023 - 2024 Apple Inc. 
#pragma once #include "mlx/backend/common/reduce.h" #include "mlx/backend/metal/device.h" #include "mlx/stream.h" namespace mlx::core { using metal::CommandEncoder; void all_reduce_dispatch( const array& in, array& out, const std::string& op_name, CommandEncoder& compute_encoder, metal::Device& d, const Stream& s); void row_reduce_general_dispatch( const array& in, array& out, const std::string& op_name, const ReductionPlan& plan, const std::vector& axes, CommandEncoder& compute_encoder, metal::Device& d, const Stream& s); void strided_reduce_general_dispatch( const array& in, array& out, const std::string& op_name, const ReductionPlan& plan, const std::vector& axes, CommandEncoder& compute_encoder, metal::Device& d, const Stream& s); } // namespace mlx::core ================================================ FILE: mlx/backend/metal/resident.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/backend/metal/resident.h" namespace mlx::core::metal { ResidencySet::ResidencySet(MTL::Device* d) { if (!d->supportsFamily(MTL::GPUFamilyMetal3)) { return; } else if (__builtin_available(macOS 15, iOS 18, *)) { auto pool = new_scoped_memory_pool(); auto desc = MTL::ResidencySetDescriptor::alloc()->init(); NS::Error* error; wired_set_ = d->newResidencySet(desc, &error); desc->release(); if (!wired_set_) { std::ostringstream msg; msg << "[metal::Device] Unable to construct residency set.\n"; if (error) { msg << error->localizedDescription()->utf8String() << "\n"; } throw std::runtime_error(msg.str()); } wired_set_->requestResidency(); } } void ResidencySet::insert(MTL::Allocation* buf) { if (!wired_set_) { return; } if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) { wired_set_->addAllocation(buf); wired_set_->commit(); } else { unwired_set_.insert(buf); } } void ResidencySet::erase(MTL::Allocation* buf) { if (!wired_set_) { return; } if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) { 
unwired_set_.erase(it); } else { wired_set_->removeAllocation(buf); wired_set_->commit(); } } void ResidencySet::resize(size_t size) { if (!wired_set_) { return; } if (capacity_ == size) { return; } capacity_ = size; size_t current_size = wired_set_->allocatedSize(); if (current_size < size) { auto pool = new_scoped_memory_pool(); // Add unwired allocations to the set for (auto it = unwired_set_.begin(); it != unwired_set_.end();) { auto buf_size = (*it)->allocatedSize(); if (current_size + buf_size > size) { it++; } else { current_size += buf_size; wired_set_->addAllocation(*it); unwired_set_.erase(it++); } } wired_set_->commit(); } else if (current_size > size) { auto pool = new_scoped_memory_pool(); // Remove wired allocations until under capacity auto allocations = wired_set_->allAllocations(); auto num_allocations = wired_set_->allocationCount(); for (int i = 0; i < num_allocations && current_size > size; ++i) { auto buf = static_cast(allocations->object(i)); wired_set_->removeAllocation(buf); current_size -= buf->allocatedSize(); unwired_set_.insert(buf); } wired_set_->commit(); } } ResidencySet::~ResidencySet() { if (wired_set_) { auto pool = new_scoped_memory_pool(); wired_set_->release(); } } } // namespace mlx::core::metal ================================================ FILE: mlx/backend/metal/resident.h ================================================ // Copyright © 2024 Apple Inc. 
#pragma once #include "mlx/backend/metal/device.h" namespace mlx::core::metal { class ResidencySet { public: ResidencySet(MTL::Device* d); ~ResidencySet(); ResidencySet(const ResidencySet&) = delete; ResidencySet& operator=(const ResidencySet&) = delete; const MTL::ResidencySet* mtl_residency_set() { return wired_set_; } void insert(MTL::Allocation* buf); void erase(MTL::Allocation* buf); void resize(size_t size); private: MTL::ResidencySet* wired_set_{nullptr}; std::unordered_set unwired_set_; size_t capacity_{0}; }; } // namespace mlx::core::metal ================================================ FILE: mlx/backend/metal/rope.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" namespace mlx::core::fast { constexpr int n_per_thread = 4; bool RoPE::use_fallback(Stream s) { return s.device == Device::cpu; } void RoPE::eval_gpu( const std::vector& inputs, std::vector& outputs) { assert(outputs.size() == 1); auto& in = inputs[0]; auto& out = outputs[0]; auto& s = out.primitive().stream(); auto& d = metal::device(s.device); int64_t strides[3]; int64_t out_strides[3]; bool donated = false; int ndim = in.ndim(); int B = in.shape(0); int T = in.shape(-2); int D = in.shape(-1); size_t mat_size = T * D; bool large = in.data_size() > INT32_MAX || in.size() > INT32_MAX; int dispatch_ndim = ndim; while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { dispatch_ndim--; } int N = 1; for (int i = 1; i < (ndim - 2); ++i) { N *= in.shape(i); } bool head_seq_transpose = false; if (dims_ < D) { donated = true; auto ctype = (in.flags().row_contiguous) ? 
CopyType::Vector : CopyType::General; copy_gpu(in, out, ctype, s); strides[0] = mat_size; strides[1] = out.strides()[ndim - 2]; strides[2] = out.strides()[ndim - 1]; } else if (in.flags().row_contiguous) { if (in.is_donatable()) { donated = true; out.copy_shared_buffer(in); } else { out.set_data(allocator::malloc(out.nbytes())); } strides[0] = mat_size; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; } else if (dispatch_ndim == 3) { // Handle non-contiguous 3D inputs out.set_data(allocator::malloc(out.nbytes())); strides[0] = in.strides()[ndim - 3]; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; } else if ( ndim == 4 && // batch dim is regularly strided in.strides()[0] == T * N * D && // sequence and head dimensions are transposed in.strides()[1] == D && in.strides()[2] == N * D) { head_seq_transpose = true; out.set_data(allocator::malloc(out.nbytes())); strides[0] = in.strides()[1]; strides[1] = in.strides()[2]; strides[2] = in.strides()[3]; } else { // Copy non-contiguous > 3D inputs into the output and treat // input as donated donated = true; copy_gpu(in, out, CopyType::General, s); strides[0] = mat_size; strides[1] = out.strides()[ndim - 2]; strides[2] = out.strides()[ndim - 1]; } out_strides[0] = mat_size; out_strides[1] = out.strides()[ndim - 2]; out_strides[2] = out.strides()[ndim - 1]; // Special case for inference (single time step, contiguous, one offset) auto& offset = inputs[1]; bool single = in.flags().row_contiguous && T == 1 && offset.size() == 1; bool with_freqs = inputs.size() == 3; std::string kname; concatenate( kname, "rope_", single ? "single_" : "", (with_freqs) ? "freqs_" : "", large ? "large_" : "", type_to_name(in)); std::string hash_name; concatenate( hash_name, kname, "_", forward_ ? "" : "vjp_", traditional_ ? "traditional_" : "", head_seq_transpose ? 
"transpose" : ""); metal::MTLFCList func_consts = { {&forward_, MTL::DataType::DataTypeBool, 1}, {&traditional_, MTL::DataType::DataTypeBool, 2}, {&head_seq_transpose, MTL::DataType::DataTypeBool, 3}}; auto kernel = d.get_kernel(kname, hash_name, func_consts); auto& compute_encoder = d.get_command_encoder(s.index); float base = std::log2(base_); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(donated ? out : in, 0); compute_encoder.set_output_array(out, 1); compute_encoder.set_input_array(offset, 2); compute_encoder.set_bytes(scale_, 3); MTL::Size group_dims; MTL::Size grid_dims; if (single) { compute_encoder.set_bytes(out_strides, 1, 4); uint32_t dim0 = dims_ / 2; group_dims = get_block_dims(dim0, N, 1); grid_dims = MTL::Size(dim0, N, 1); } else { compute_encoder.set_bytes(strides, 3, 4); compute_encoder.set_bytes(out_strides, 3, 5); int64_t offset_stride = 0; if (offset.ndim() > 0) { offset_stride = offset.strides()[0]; } compute_encoder.set_bytes(offset_stride, 6); compute_encoder.set_bytes(N, 7); uint32_t dim0 = dims_ / 2; uint32_t dim1 = T; uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread); group_dims = get_block_dims(dim0, dim1, dim2); grid_dims = MTL::Size(dim0, dim1, dim2); } if (with_freqs) { auto& freqs = inputs[2]; compute_encoder.set_input_array(freqs, 10); auto freq_stride = freqs.strides()[0]; compute_encoder.set_bytes(freq_stride, 11); } else { compute_encoder.set_bytes(base, 10); } compute_encoder.dispatch_threads(grid_dims, group_dims); } } // namespace mlx::core::fast ================================================ FILE: mlx/backend/metal/scaled_dot_product_attention.cpp ================================================ // Copyright © 2024 Apple Inc. 
// NOTE(review): the header name of the first #include below is missing
// (apparently lost when this file was extracted) -- restore it.
#include
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"

namespace mlx::core::fast {

namespace {

// Full (multi-token) attention through the NAX steel attention kernel
// (get_steel_attention_nax_kernel), with query tiles bq = 64 and key tiles
// bk = 32.
// One threadgroup of 32 x wm x wn threads is launched per (query tile, head,
// batch). Alignment/mask/causal/sinks options are specialized through Metal
// function constants 200, 201, 300, 301 and 302.
void sdpa_full_self_attention_nax(
    const Stream& s,
    metal::Device& d,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_,
    const std::optional& mask,
    const std::optional& sinks) {
  using namespace mlx::steel;

  int wm = 4;
  int wn = 1;

  int bd = q.shape(-1);
  int bq = 64;
  int bk = 32;

  // q is indexed as [B, H, qL, D]; k has H / gqa_factor heads at dim 1.
  int B = q.shape(0);
  int H = q.shape(1);
  int D = q.shape(3);
  int gqa_factor = q.shape(1) / k.shape(1);

  int qL = q.shape(2);
  int kL = k.shape(2);

  // True when a sequence length is an exact multiple of its tile size.
  const bool align_Q = (qL % bq) == 0;
  const bool align_K = (kL % bk) == 0;
  const bool has_mask = mask.has_value();
  const bool do_causal = do_causal_;
  const bool has_sinks = sinks.has_value();

  metal::MTLFCList func_consts = {
      {&align_Q, MTL::DataType::DataTypeBool, 200},
      {&align_K, MTL::DataType::DataTypeBool, 201},
      {&has_mask, MTL::DataType::DataTypeBool, 300},
      {&do_causal, MTL::DataType::DataTypeBool, 301},
      {&has_sinks, MTL::DataType::DataTypeBool, 302}};

  // base_name names the kernel instantiation; hash_name additionally encodes
  // the function-constant values (presumably the key for the specialized
  // pipeline).
  std::string base_name;
  concatenate(
      base_name,
      "steel_attention_",
      type_to_name(q),
      "_bq",
      bq,
      "_bk",
      bk,
      "_bd",
      bd,
      "_wm",
      wm,
      "_wn",
      wn,
      "_mask",
      type_to_name(has_mask ? *mask : q));

  std::string hash_name;
  concatenate(
      hash_name,
      base_name,
      "_align_Q_",
      (align_Q ? 't' : 'n'),
      "_align_K_",
      (align_K ? 't' : 'n'),
      "_has_mask_",
      (has_mask ? 't' : 'n'),
      "_do_causal_",
      (do_causal ? 't' : 'n'),
      "_has_sinks_",
      (has_sinks ? 't' : 'n'));

  auto& compute_encoder = d.get_command_encoder(s.index);

  auto kernel = get_steel_attention_nax_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      q,
      bq,
      bk,
      bd,
      wm,
      wn,
      (has_mask ? *mask : q));

  compute_encoder.set_compute_pipeline_state(kernel);

  // Tile counts: NQ/NK round up, *_aligned round down; the *_rem fields below
  // are the sizes of the trailing partial tiles.
  const int NQ = (qL + bq - 1) / bq;
  const int NK = (kL + bk - 1) / bk;

  const int NQ_aligned = qL / bq;
  const int NK_aligned = kL / bk;

  AttnParams params{
      /* int B = */ B,
      /* int H = */ H,
      /* int D = */ D,

      /* int qL = */ qL,
      /* int kL = */ kL,

      /* int gqa_factor = */ gqa_factor,
      /* float scale = */ scale,

      /* int NQ = */ NQ,
      /* int NK = */ NK,

      /* int NQ_aligned = */ NQ_aligned,
      /* int NK_aligned = */ NK_aligned,

      /* int qL_rem = */ (qL - NQ_aligned * bq),
      /* int kL_rem = */ (kL - NK_aligned * bk),
      /* int qL_off = */ (kL - qL),

      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

  compute_encoder.set_input_array(q, 0);
  compute_encoder.set_input_array(k, 1);
  compute_encoder.set_input_array(v, 2);
  compute_encoder.set_output_array(o, 3);
  compute_encoder.set_bytes(params, 4);

  if (has_mask) {
    auto& m = *mask;

    AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
        m.strides(0), m.strides(1), m.strides(2)}};

    compute_encoder.set_bytes(mask_params, 5);
    compute_encoder.set_input_array(m, 6);
  }
  if (has_sinks) {
    compute_encoder.set_input_array(*sinks, 7);
  }

  MTL::Size grid_dims = MTL::Size(NQ, H, B);
  MTL::Size group_dims = MTL::Size(32, wm, wn);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Full (multi-token) attention through the steel attention kernel.
// Delegates to sdpa_full_self_attention_nax when all of these hold:
//   - metal::is_nax_available(),
//   - the head dim is not 80,
//   - TF32 is enabled or q is not float32.
// Otherwise it uses query tiles of 32 and key tiles of 32 (head dim < 128) or
// 16, with the same buffer layout and launch shape as the NAX path.
void sdpa_full_self_attention_metal(
    const Stream& s,
    metal::Device& d,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_,
    const std::optional& mask,
    const std::optional& sinks) {
  if (metal::is_nax_available() && q.shape(3) != 80 &&
      (env::enable_tf32() || q.dtype() != float32)) {
    return sdpa_full_self_attention_nax(
        /* const Stream& s = */ s,
        /* metal::Device& d = */ d,
        /* const array& q = */ q,
        /* const array& k = */ k,
        /* const array& v = */ v,
        /* const float scale = */ scale,
        /* array& o = */ o,
        /* bool do_causal_ = */ do_causal_,
        /* const std::optional& mask = */ mask,
        /* const std::optional& sinks = */ sinks);
  }

  using namespace mlx::steel;

  int wm = 4;
  int wn = 1;

  int bd = q.shape(-1);
  int bq = 32;
  int bk = bd < 128 ? 32 : 16;

  int B = q.shape(0);
  int H = q.shape(1);
  int D = q.shape(3);
  int gqa_factor = q.shape(1) / k.shape(1);

  int qL = q.shape(2);
  int kL = k.shape(2);

  const bool align_Q = (qL % bq) == 0;
  const bool align_K = (kL % bk) == 0;
  const bool has_mask = mask.has_value();
  const bool do_causal = do_causal_;
  const bool has_sinks = sinks.has_value();

  metal::MTLFCList func_consts = {
      {&align_Q, MTL::DataType::DataTypeBool, 200},
      {&align_K, MTL::DataType::DataTypeBool, 201},
      {&has_mask, MTL::DataType::DataTypeBool, 300},
      {&do_causal, MTL::DataType::DataTypeBool, 301},
      {&has_sinks, MTL::DataType::DataTypeBool, 302}};

  std::string base_name;
  concatenate(
      base_name,
      "steel_attention_",
      type_to_name(q),
      "_bq",
      bq,
      "_bk",
      bk,
      "_bd",
      bd,
      "_wm",
      wm,
      "_wn",
      wn,
      "_mask",
      type_to_name(has_mask ? *mask : q));

  std::string hash_name;
  concatenate(
      hash_name,
      base_name,
      "_align_Q_",
      (align_Q ? 't' : 'n'),
      "_align_K_",
      (align_K ? 't' : 'n'),
      "_has_mask_",
      (has_mask ? 't' : 'n'),
      "_do_causal_",
      (do_causal ? 't' : 'n'),
      "_has_sinks_",
      (has_sinks ? 't' : 'n'));

  auto& compute_encoder = d.get_command_encoder(s.index);

  auto kernel = get_steel_attention_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      q,
      bq,
      bk,
      bd,
      wm,
      wn,
      (has_mask ? *mask : q));

  compute_encoder.set_compute_pipeline_state(kernel);

  const int NQ = (qL + bq - 1) / bq;
  const int NK = (kL + bk - 1) / bk;

  const int NQ_aligned = qL / bq;
  const int NK_aligned = kL / bk;

  AttnParams params{
      /* int B = */ B,
      /* int H = */ H,
      /* int D = */ D,

      /* int qL = */ qL,
      /* int kL = */ kL,

      /* int gqa_factor = */ gqa_factor,
      /* float scale = */ scale,

      /* int NQ = */ NQ,
      /* int NK = */ NK,

      /* int NQ_aligned = */ NQ_aligned,
      /* int NK_aligned = */ NK_aligned,

      /* int qL_rem = */ (qL - NQ_aligned * bq),
      /* int kL_rem = */ (kL - NK_aligned * bk),
      /* int qL_off = */ (kL - qL),

      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

  compute_encoder.set_input_array(q, 0);
  compute_encoder.set_input_array(k, 1);
  compute_encoder.set_input_array(v, 2);
  compute_encoder.set_output_array(o, 3);
  compute_encoder.set_bytes(params, 4);

  if (has_mask) {
    auto& m = *mask;

    AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
        m.strides(0), m.strides(1), m.strides(2)}};

    compute_encoder.set_bytes(mask_params, 5);
    compute_encoder.set_input_array(m, 6);
  }
  if (has_sinks) {
    compute_encoder.set_input_array(*sinks, 7);
  }

  MTL::Size grid_dims = MTL::Size(NQ, H, B);
  MTL::Size group_dims = MTL::Size(32, wm, wn);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Single-pass attention for short queries, using kernel
// "sdpa_vector_<type>_<q head dim>_<v head dim>". One 1024-thread threadgroup
// is launched per (batch * head, query position).
void sdpa_vector(
    const Stream& s,
    metal::Device& d,
    const array& q,
    const array& k,
    const array& v,
    array& out,
    float scale,
    bool do_causal,
    const std::optional& mask,
    const std::optional& sinks) {
  // Set the kernel name
  std::string kname;
  kname.reserve(64);
  kname += "sdpa_vector_";
  kname += get_type_string(q.dtype());
  kname += "_";
  kname += std::to_string(q.shape(-1));
  kname += "_";
  kname += std::to_string(v.shape(-1));

  // Compute the necessary sizes
  int gqa_factor = q.shape(1) / k.shape(1);
  int N = k.shape(2);
  // With a single kv head, step between heads using the batch stride.
  size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
  size_t k_seq_stride = k.strides()[2];
  size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
  size_t v_seq_stride = v.strides()[2];

  MTL::Size group_dims(1024, 1, 1);
  MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1);

  bool has_mask = mask.has_value();
  bool bool_mask = has_mask && (*mask).dtype() == bool_;
  bool float_mask = has_mask && !bool_mask;
  bool query_transposed = !q.flags().row_contiguous;
  bool has_sinks = sinks.has_value();
  metal::MTLFCList func_consts = {
      {&has_mask, MTL::DataType::DataTypeBool, 20},
      {&query_transposed, MTL::DataType::DataTypeBool, 21},
      {&do_causal, MTL::DataType::DataTypeBool, 22},
      {&bool_mask, MTL::DataType::DataTypeBool, 23},
      {&float_mask, MTL::DataType::DataTypeBool, 24},
      {&has_sinks, MTL::DataType::DataTypeBool, 25},
  };
  std::string hash_name = kname;
  hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
  hash_name += query_transposed ? "_qt" : "_qnt";
  hash_name += do_causal ? "_c" : "_nc";
  hash_name += has_sinks ? "_sinks" : "_nosinks";

  // Get the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname, hash_name, func_consts);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Set its arguments
  compute_encoder.set_input_array(q, 0);
  compute_encoder.set_input_array(k, 1);
  compute_encoder.set_input_array(v, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(gqa_factor, 4);
  compute_encoder.set_bytes(N, 5);
  compute_encoder.set_bytes(k_head_stride, 6);
  compute_encoder.set_bytes(k_seq_stride, 7);
  compute_encoder.set_bytes(v_head_stride, 8);
  compute_encoder.set_bytes(v_seq_stride, 9);

  compute_encoder.set_bytes(scale, 10);
  if (has_mask) {
    auto& m = *mask;
    // Bool masks bind at index 11, float masks at index 12.
    compute_encoder.set_input_array(m, 11 + float_mask);
    // A broadcast (size-1) mask dim gets stride 0. The head stride falls back
    // to the batch stride when the head dim is broadcast.
    int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
    int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
    int32_t head_stride =
        m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
    compute_encoder.set_bytes(kv_seq_stride, 13);
    compute_encoder.set_bytes(q_seq_stride, 14);
    compute_encoder.set_bytes(head_stride, 15);
  }
  if (has_sinks) {
    compute_encoder.set_input_array(*sinks, 16);
    compute_encoder.set_bytes(q.shape(1), 17);
  }

  // Launch
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Two-pass attention for short queries over long key sequences.
// Pass 1 ("sdpa_vector_2pass_1_...") splits the keys into `blocks` chunks.
// It writes per-block partial outputs (dtype of q) plus float32 `sums` and
// `maxs` temporaries. Pass 2 ("sdpa_vector_2pass_2_...") reads those three and
// writes `out`.
// `blocks` is chosen from the last character of the device architecture
// string, N (key length) and n_simds (gqa_factor * query length).
void sdpa_vector_2pass(
    const Stream& s,
    metal::Device& d,
    const array& q,
    const array& k,
    const array& v,
    array& out,
    float scale,
    bool do_causal,
    const std::optional& mask,
    const std::optional& sinks) {
  // Set the kernel name
  std::string kname;
  kname.reserve(64);
  kname += "sdpa_vector_2pass_1_";
  kname += get_type_string(q.dtype());
  kname += "_";
  kname += std::to_string(q.shape(-1));
  kname += "_";
  kname += std::to_string(v.shape(-1));

  // Compute the necessary sizes
  int gqa_factor = q.shape(1) / k.shape(1);
  int n_simds = gqa_factor * q.shape(2);

  char devc = d.get_architecture().back();
  int N = k.shape(2);
  int blocks;
  if (devc == 's') {
    blocks = 64;
    if (N > 1024 && n_simds > 4) {
      if (N <= 8192) {
        blocks = 128;
      } else if (N <= 32768) {
        blocks = 256;
      } else if (N <= 65536) {
        blocks = 512;
      } else {
        blocks = 1024;
      }
    }
  } else if (devc == 'd') {
    blocks = 128;
    if (n_simds <= 2 && N > 8192) {
      blocks = 256;
    } else if (n_simds >= 6) {
      if (N >= 16384 && N < 65536) {
        blocks = 512;
      } else if (N >= 65536) {
        blocks = 1024;
      }
    }
  } else {
    if (n_simds >= 4) {
      blocks = 64;
    } else {
      blocks = 32;
    }
  }
  size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
  size_t k_seq_stride = k.strides()[2];
  size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
  size_t v_seq_stride = v.strides()[2];
  // Threadgroup: 32 x gqa_factor x query length; grid: (kv heads, batch,
  // blocks).
  MTL::Size group_dims(32, gqa_factor, q.shape(2));
  MTL::Size grid_dims(k.shape(1), q.shape(0), blocks);

  // Allocate the intermediates
  // intermediate: out.shape()[:-1] + [blocks, out.shape(-1)];
  // sums / maxs: out.shape()[:-1] + [blocks].
  Shape intermediate_shape;
  intermediate_shape.reserve(out.ndim() + 1);
  intermediate_shape.insert(
      intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
  intermediate_shape.push_back(blocks);
  intermediate_shape.push_back(out.shape().back());
  array intermediate(intermediate_shape, q.dtype(), nullptr, {});
  intermediate_shape.pop_back();
  array sums(intermediate_shape, float32, nullptr, {});
  array maxs(std::move(intermediate_shape), float32, nullptr, {});
  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
  sums.set_data(allocator::malloc(sums.nbytes()));
  maxs.set_data(allocator::malloc(maxs.nbytes()));
  d.add_temporary(intermediate, s.index);
  d.add_temporary(sums, s.index);
  d.add_temporary(maxs, s.index);

  bool has_mask = mask.has_value();
  bool bool_mask = has_mask && (*mask).dtype() == bool_;
  bool float_mask = has_mask && !bool_mask;
  bool query_transposed = !q.flags().row_contiguous;
  bool has_sinks = sinks.has_value();
  metal::MTLFCList func_consts = {
      {&has_mask, MTL::DataType::DataTypeBool, 20},
      {&query_transposed, MTL::DataType::DataTypeBool, 21},
      {&do_causal, MTL::DataType::DataTypeBool, 22},
      {&bool_mask, MTL::DataType::DataTypeBool, 23},
      {&float_mask, MTL::DataType::DataTypeBool, 24},
      {&has_sinks, MTL::DataType::DataTypeBool, 25},
      {&blocks, MTL::DataType::DataTypeInt, 26},
  };
  std::string hash_name = kname;
  hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
  hash_name += query_transposed ? "_qt" : "_qnt";
  hash_name += do_causal ? "_c" : "_nc";
  hash_name += has_sinks ? "_sinks_" : "_nosinks_";
  hash_name += std::to_string(blocks);

  // Get the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname, hash_name, func_consts);
  check_kernel_threadgroup_size(kernel, group_dims, hash_name);

  compute_encoder.set_compute_pipeline_state(kernel);

  // Set its arguments
  // NOTE(review): buffer index 6 is not bound here and gqa_factor is not
  // passed (unlike sdpa_vector); presumably the kernel derives it from the
  // threadgroup's y extent -- confirm.
  compute_encoder.set_input_array(q, 0);
  compute_encoder.set_input_array(k, 1);
  compute_encoder.set_input_array(v, 2);
  compute_encoder.set_output_array(intermediate, 3);
  compute_encoder.set_output_array(sums, 4);
  compute_encoder.set_output_array(maxs, 5);
  compute_encoder.set_bytes(N, 7);
  compute_encoder.set_bytes(k_head_stride, 8);
  compute_encoder.set_bytes(k_seq_stride, 9);
  compute_encoder.set_bytes(v_head_stride, 10);
  compute_encoder.set_bytes(v_seq_stride, 11);
  compute_encoder.set_bytes(scale, 12);
  if (has_mask) {
    auto& m = *mask;
    // Bool masks bind at index 13, float masks at index 14.
    compute_encoder.set_input_array(m, 13 + float_mask);
    int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
    int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
    int32_t head_stride =
        m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
    compute_encoder.set_bytes(kv_seq_stride, 15);
    compute_encoder.set_bytes(q_seq_stride, 16);
    compute_encoder.set_bytes(head_stride, 17);
  }
  if (has_sinks) {
    compute_encoder.set_input_array(*sinks, 18);
  }

  // Launch
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

  // Final pass
  kname.clear();
  kname = "sdpa_vector_2pass_2_";
  kname += get_type_string(q.dtype());
  kname += "_";
  kname += std::to_string(v.shape(-1));

  // Get the kernel
  kernel = d.get_kernel(kname);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Set its arguments
  compute_encoder.set_input_array(intermediate, 0);
  compute_encoder.set_input_array(sums, 1);
  compute_encoder.set_input_array(maxs, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(blocks, 4);

  // Launch
  group_dims = MTL::Size(1024, 1, 1);
  grid_dims = MTL::Size(q.shape(0) * q.shape(1), q.shape(2), 1);
  check_kernel_threadgroup_size(kernel, group_dims, kname);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

} // namespace

// Returns true when the fused Metal kernels should not handle this call:
//   - training,
//   - logsumexp output,
//   - CPU streams,
//   - head dims / masks / lengths outside what the kernels support.
// Coverage: sdpa_vector handles query length <= 8 (with extra conditions);
// the full steel kernel handles query length > 8.
bool ScaledDotProductAttention::use_fallback(
    const array& q,
    const array& k,
    const array& v,
    bool has_mask,
    bool has_arr_mask,
    bool do_causal,
    bool is_training,
    bool output_logsumexp,
    Stream s) {
  if (is_training) {
    // It's faster for training on Metal to use the unfused SDPA for both
    // forward and backward.
    return true;
  }
  if (output_logsumexp) {
    return true;
  }
  if (s.device == Device::cpu) {
    return true;
  }

  const int value_head_dim = v.shape(-1);
  const int query_head_dim = q.shape(-1);
  const int query_sequence_length = q.shape(2);
  const int key_sequence_length = k.shape(2);
  const int num_query_heads = q.shape(1);
  const int num_kv_heads = k.shape(1);
  const int gqa_factor = num_query_heads / num_kv_heads;

  const bool sdpa_vector_supported_head_dim =
      query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
       query_head_dim == 256);
  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);

  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
      (query_sequence_length <= key_sequence_length && do_causal);

  const bool supports_sdpa_full = query_sequence_length > 8 &&
      sdpa_full_supported_mask && sdpa_full_supported_head_dim;

  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
      (query_sequence_length <= key_sequence_length) &&
      sdpa_vector_supported_head_dim &&
      (query_sequence_length * gqa_factor) <= 32;

  return !(supports_sdpa_full || supports_sdpa_vector);
}

bool ScaledDotProductAttention::supports_bool_mask() {
  return true;
}

// Evaluates SDPA on the GPU. inputs are [q, k, v, (array mask), (sinks)].
// Query lengths <= 8 use the vector kernels (one or two passes); longer
// queries use the full steel attention kernel. Inputs whose layout the chosen
// kernel cannot consume are first copied to contiguous temporaries.
void ScaledDotProductAttention::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& q_pre = inputs[0];
  auto& k_pre = inputs[1];
  auto& v_pre = inputs[2];
  auto& o = outputs[0];

  std::vector copies;

  // Define some copy functions to ensure the layout of the inputs is as
  // expected.
  // The reserve is load-bearing: copy_unless returns a reference into
  // `copies`. At most one copy is pushed per input, so no later push_back
  // reallocates and invalidates an earlier reference.
  copies.reserve(inputs.size());
  auto copy_unless = [&copies, &s](
                         auto predicate, const array& arr) -> const array& {
    if (!predicate(arr)) {
      array arr_copy = contiguous_copy_gpu(arr, s);
      copies.push_back(std::move(arr_copy));
      return copies.back();
    } else {
      return arr;
    }
  };

  // Checks that the headdim dimension has stride 1.
  auto is_matrix_contiguous = [](const array& arr) {
    return arr.strides(-1) == 1;
  };

  std::optional sinks = std::nullopt;
  if (has_sinks_) {
    sinks = copy_unless(is_matrix_contiguous, inputs.back());
  }
  bool has_arr_mask = inputs.size() > (3 + has_sinks_);

  // We are in vector mode ie single query
  if (q_pre.shape(2) <= 8) {
    // Returns true when q can be used without a copy.
    auto q_copy_unless = [](const array& arr) {
      if (arr.flags().row_contiguous) {
        return true;
      }
      auto& strides = arr.strides();
      auto& shape = arr.shape();
      if (shape[0] == 1 || shape[1] == 1) {
        // If either the batch or head dimension is a singleton, the other can
        // be transposed with the sequence dimension
        auto bidx = shape[0] == 1 ? 1 : 0;
        return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
            (strides[bidx] == shape[3]);
      }
      return false;
    };

    auto kv_copy_unless = [](const array& arr) {
      // keys and values should be copied if:
      // - the last dimension is not contiguous
      // - the batch and head dim are not contiguous
      auto& strides = arr.strides();
      auto& shape = arr.shape();
      if (strides.back() != 1) {
        return false;
      }
      if (shape[0] == 1 || shape[1] == 1) {
        return true;
      }
      return (strides[0] == strides[1] * shape[1]);
    };

    bool q_copied = !q_copy_unless(q_pre);
    array q = (q_copied) ? contiguous_copy_gpu(q_pre, s) : q_pre;
    const auto& k = copy_unless(kv_copy_unless, k_pre);
    const auto& v = copy_unless(kv_copy_unless, v_pre);

    // Donate the query if possible
    // (otherwise a copied q is kept in `copies` so it outlives the launch).
    if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
      o.copy_shared_buffer(q);
    } else {
      if (q_copied) {
        copies.push_back(q);
      }
      o.set_data(allocator::malloc(o.nbytes()));
    }

    auto mask_copy_unless = [&q](const array& arr) {
      auto& strides = arr.strides();
      auto& shape = arr.shape();
      return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||
          (strides[0] == strides[1] * shape[1]);
    };

    auto mask = has_arr_mask
        ? std::optional{copy_unless(mask_copy_unless, inputs[3])}
        : std::nullopt;

    // We route to the 2 pass fused attention if
    // - The device is large and the sequence length long
    // - The sequence length is even longer and we have gqa
    // Causal masking is dropped for a single query token.
    bool do_causal = do_causal_ && q.shape(2) > 1;
    char devc = d.get_architecture().back();
    if (((devc == 'd' || devc == 's') && k.shape(2) >= 1024) ||
        (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
      sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
    } else {
      sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
    }
  }

  // Full attention mode
  else {
    const auto& q = copy_unless(is_matrix_contiguous, q_pre);
    const auto& k = copy_unless(is_matrix_contiguous, k_pre);
    const auto& v = copy_unless(is_matrix_contiguous, v_pre);

    // The output is laid out in [B, L, H, D] memory order while keeping its
    // logical [B, H, L, D] shape (strides listed in B, H, L, D order), so it
    // is contiguous but not row-contiguous.
    int64_t str_oD = 1;
    int64_t str_oH = o.shape(3);
    int64_t str_oL = o.shape(1) * str_oH;
    int64_t str_oB = o.shape(2) * str_oL;
    size_t data_size = o.shape(0) * str_oB;

    array::Flags flags{
        /* bool contiguous = */ 1,
        /* bool row_contiguous = */ 0,
        /* bool col_contiguous = */ 0,
    };

    o.set_data(
        allocator::malloc(o.nbytes()),
        data_size,
        {str_oB, str_oH, str_oL, str_oD},
        flags);

    auto mask = has_arr_mask
        ? std::optional{copy_unless(is_matrix_contiguous, inputs[3])}
        : std::nullopt;

    sdpa_full_self_attention_metal(
        s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
  }

  d.add_temporaries(std::move(copies), s.index);
}

// The fused VJP is not implemented on Metal: use_fallback always returns true
// and eval_gpu throws "NYI".
bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
  return true;
}

void ScaledDotProductAttentionVJP::eval_gpu(
    const std::vector& inputs,
    std::vector& outputs) {
  throw std::runtime_error("NYI");
}

} // namespace mlx::core::fast
================================================
FILE: mlx/backend/metal/scan.cpp
================================================
// Copyright © 2023 Apple Inc.
#include
#include
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/scan.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

// Encode a scan (sum / prod / max / min / logaddexp) of `in` along `axis`
// into `out`. `out`'s data is not allocated here; the caller
// (e.g. Scan::eval_gpu) sets it up first.
//
// The kernel name encodes every variant:
//   "contig_" | "strided_"  - whether the scanned axis has stride 1
//   "reverse_"              - scan from the end of the axis
//   "inclusive_" | "exclusive_"
//   <reduce op>_<in dtype>_<out dtype>
void scan_gpu_inplace(
    array in,
    array& out,
    Scan::ReduceType reduce_type,
    int axis,
    bool reverse,
    bool inclusive,
    const Stream& s) {
  auto& d = metal::device(s.device);

  bool contiguous = in.strides()[axis] == 1;

  std::string reduce_type_str;
  switch (reduce_type) {
    case Scan::Sum:
      reduce_type_str = "sum";
      break;
    case Scan::Prod:
      reduce_type_str = "prod";
      break;
    case Scan::Max:
      reduce_type_str = "max";
      break;
    case Scan::Min:
      reduce_type_str = "min";
      break;
    case Scan::LogAddExp:
      reduce_type_str = "logaddexp";
      break;
  }

  std::string kname;
  concatenate(
      kname,
      contiguous ? "contig_" : "strided_",
      "scan_",
      reverse ? "reverse_" : "",
      inclusive ? "inclusive_" : "exclusive_",
      reduce_type_str,
      "_",
      type_to_name(in),
      "_",
      type_to_name(out));

  auto kernel =
      get_scan_kernel(d, kname, reverse, inclusive, reduce_type_str, in, out);

  if (contiguous) {
    auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder.set_compute_pipeline_state(kernel);
    compute_encoder.set_input_array(in, 0);
    compute_encoder.set_output_array(out, 1);
    size_t size = in.shape(axis);
    compute_encoder.set_bytes(size, 2);

    // Compute the thread grid
    // Elements read per thread: 4 for dtypes of <= 4 bytes, otherwise 2.
    int n_reads = (in.itemsize() <= 4) ? 4 : 2;
    constexpr int simd_size = 32;
    int elements_per_simd = n_reads * simd_size;
    int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    // Shrink the threadgroup for short axes to whole SIMD groups covering the
    // axis (or half of it for sizes up to n_reads * 2048). The result is
    // capped by the kernel maximum below.
    if (size <= n_reads * 1024) {
      thread_group_size =
          ((size + elements_per_simd - 1) / elements_per_simd) * simd_size;
    } else if (size <= n_reads * 2048) {
      thread_group_size =
          ((size / 2 + elements_per_simd - 1) / elements_per_simd) * simd_size;
    }
    thread_group_size = std::min(
        thread_group_size,
        static_cast(kernel->maxTotalThreadsPerThreadgroup()));
    // One threadgroup (x) per scanned row; rows are spread over grid y/z.
    auto tmp_grid_dims =
        get_2d_grid_dims(in.shape(), in.strides(), /*divisor=*/size);
    MTL::Size grid_dims(
        thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
    MTL::Size group_dims(thread_group_size, 1, 1);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  } else {
    auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder.set_compute_pipeline_state(kernel);
    compute_encoder.set_input_array(in, 0);
    compute_encoder.set_output_array(out, 1);
    size_t size = in.shape(axis);
    size_t stride = in.strides()[axis];
    // Positions within one stride are handled in blocks of bn.
    int bn = 32;
    size_t stride_blocks = (stride + bn - 1) / bn;
    compute_encoder.set_bytes(size, 2);
    compute_encoder.set_bytes(stride, 3);
    compute_encoder.set_bytes(stride_blocks, 4);

    // Compute the thread grid
    int n_reads = (in.itemsize() <= 4) ? 4 : 2;
    int n_simdgroups = bn / n_reads;
    int thread_group_size = n_simdgroups * 32;
    auto tmp_grid_dims = get_2d_grid_dims(
        in.shape(), in.strides(), /*divisor=*/size * stride);
    // Fold the stride blocks into the grid width unless that would exceed
    // the 32-bit limit, in which case fold them into the height.
    if (tmp_grid_dims.width * stride_blocks <= UINT_MAX) {
      tmp_grid_dims.width *= stride_blocks;
    } else {
      tmp_grid_dims.height *= stride_blocks;
    }
    MTL::Size grid_dims(
        thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
    MTL::Size group_dims(thread_group_size, 1, 1);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  }
}

// Scan along axis_. When the input is contiguous and the scan axis is not
// broadcast (stride != 0), out reuses the input's layout, donating the input
// buffer when possible. Otherwise the input is first made contiguous and
// scanned in place into that copy's buffer.
void Scan::eval_gpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  auto in = inputs[0];
  if (in.flags().contiguous && in.strides()[axis_] != 0) {
    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
      out.copy_shared_buffer(in);
    } else {
      out.set_data(
          allocator::malloc(in.data_size() * out.itemsize()),
          in.data_size(),
          in.strides(),
          in.flags());
    }
  } else {
    in = contiguous_copy_gpu(in, stream());
    out.copy_shared_buffer(in);
  }

  scan_gpu_inplace(
      in, out, reduce_type_, axis_, reverse_, inclusive_, stream());
}

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/slicing.cpp
================================================
// Copyright © 2024 Apple Inc.
#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" namespace mlx::core { void concatenate_gpu( const std::vector& inputs, array& out, int axis, const Stream& s) { std::vector sizes; sizes.push_back(0); for (auto& p : inputs) { sizes.push_back(p.shape(axis)); } std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); out.set_data(allocator::malloc(out.nbytes())); auto strides = out.strides(); auto flags = out.flags(); flags.row_contiguous = false; flags.col_contiguous = false; flags.contiguous = false; auto& d = metal::device(s.device); auto& compute_encoder = d.get_command_encoder(s.index); auto concurrent_ctx = compute_encoder.start_concurrent(); for (int i = 0; i < inputs.size(); i++) { array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); size_t data_offset = strides[axis] * sizes[i]; out_slice.copy_shared_buffer( out, strides, flags, out_slice.size(), data_offset); copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s); } } array compute_dynamic_offset( const array& indices, const Strides& strides, const std::vector& axes, const Stream& s) { auto& d = metal::device(s.device); // Kernel to compute offset here. 
array offset({1}, int64, nullptr, {}); bool donate = indices.is_donatable() && (indices.data_size() * indices.itemsize()) >= offset.itemsize(); if (donate) { offset.copy_shared_buffer(indices); } else { offset.set_data(allocator::malloc(offset.itemsize())); } d.add_temporary(offset, s.index); auto dtype = indices.dtype(); std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype); auto lib = d.get_library(lib_name, [dtype]() { return fmt::format( R"( [[kernel]] void compute_dynamic_offset_{0}( constant const {1}* indices [[buffer(0)]], device int64_t& offset [[buffer(1)]], constant const int64_t* strides [[buffer(2)]], constant const int* axes [[buffer(3)]], constant const int& n_axes [[buffer(4)]], uint index [[thread_position_in_grid]]) {{ int64_t acc = 0; for (int i = 0; i < n_axes; ++i) {{ acc += indices[i] * strides[axes[i]]; }} offset = acc; }})", type_to_name(dtype), get_type_string(dtype)); }); auto kernel = d.get_kernel(lib_name, lib); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(indices, 0); compute_encoder.set_output_array(offset, 1); compute_encoder.set_vector_bytes(strides, 2); compute_encoder.set_vector_bytes(axes, 3); int n_axes = axes.size(); compute_encoder.set_bytes(n_axes, 4); MTL::Size dims = MTL::Size(1, 1, 1); compute_encoder.dispatch_threads(dims, dims); return offset; } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/softmax.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" namespace mlx::core { constexpr int SOFTMAX_LOOPED_LIMIT = 4096; void Softmax::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[softmax] Does not support non-floating point types."); } auto& s = stream(); auto& d = metal::device(s.device); // Make sure that the last dimension is contiguous auto set_output = [&s, &out](const array& x) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.is_donatable()) { out.copy_shared_buffer(x); } else { out.set_data( allocator::malloc(x.data_size() * x.itemsize()), x.data_size(), x.strides(), x.flags()); } return x; } else { array x_copy = contiguous_copy_gpu(x, s); out.copy_shared_buffer(x_copy); return x_copy; } }; const array in = set_output(inputs[0]); int axis_size = in.shape().back(); int n_rows = in.data_size() / axis_size; const int simd_size = 32; const int n_reads = SOFTMAX_N_READS; const int looped_limit = SOFTMAX_LOOPED_LIMIT; std::string kernel_name = (axis_size > looped_limit) ? 
"looped_" : "block_"; kernel_name += "softmax_"; if (in.dtype() != float32 && precise_) { kernel_name += "precise_"; } kernel_name += type_to_name(out); auto kernel = get_softmax_kernel(d, kernel_name, precise_, out); auto& compute_encoder = d.get_command_encoder(s.index); { MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } else { size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup(); size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); } compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); compute_encoder.set_bytes(axis_size, 2); compute_encoder.dispatch_threads(grid_dims, group_dims); } } } // namespace mlx::core ================================================ FILE: mlx/backend/metal/sort.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

// Sort (or, with `argsort`, argsort) every row of `in` along `axis` into
// `out`. One threadgroup of `bn` threads is used per row, dispatched as a
// (1, n_rows, 1) grid of threadgroups.
//
// The "c" kernel is used when `in` is contiguous and the sorted axis has the
// smallest or largest stride in both `in` and `out`. Otherwise the "nc"
// kernel is used, and it receives the full shape/strides of the non-sorted
// dimensions.
void single_block_sort(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array& out,
    int axis,
    int bn,
    int tn,
    bool argsort) {
  // Prepare shapes
  int n_rows = in.size() / in.shape(axis);

  // Strides / shape of the non-sorted dimensions (sorted axis removed).
  auto in_nc_str = in.strides();
  in_nc_str.erase(in_nc_str.begin() + axis);

  auto out_nc_str = out.strides();
  out_nc_str.erase(out_nc_str.begin() + axis);

  auto nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);

  int nc_dim = nc_shape.size();

  int size_sorted_axis = in.shape(axis);
  // NOTE(review): the sorted-axis strides (and the min/max strides in
  // check_strides) are narrowed to int without the INT32_MAX check that is
  // applied to the segment strides below.
  int in_stride_sorted_axis = in.strides()[axis];
  int out_stride_sorted_axis = out.strides()[axis];

  // We can only use the contiguous kernel if the sorted axis
  // has the largest or smallest stride.
  // We also need the input to be contiguous
  bool contiguous = in.flags().contiguous;
  auto check_strides = [](array x, int sort_stride) {
    int min_stride = *std::min_element(x.strides().begin(), x.strides().end());
    int max_stride = *std::max_element(x.strides().begin(), x.strides().end());
    return sort_stride == min_stride || sort_stride == max_stride;
  };
  contiguous &= check_strides(in, in_stride_sorted_axis);
  contiguous &= check_strides(out, out_stride_sorted_axis);

  // Prepare kernel name
  std::ostringstream kname;
  kname << (contiguous ? "c" : "nc");
  if (argsort) {
    kname << "arg";
  }

  kname << "_block_sort_" << type_to_name(in) << "_" << type_to_name(out)
        << "_bn" << bn << "_tn" << tn;
  auto kernel = get_sort_kernel(d, kname.str(), in, out, bn, tn);

  // Prepare command encoder
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Set inputs
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(out, 1);
  compute_encoder.set_bytes(size_sorted_axis, 2);
  compute_encoder.set_bytes(in_stride_sorted_axis, 3);
  compute_encoder.set_bytes(out_stride_sorted_axis, 4);

  if (contiguous) {
    // Pass the smallest stride among non-sorted dims with extent > 1
    // (INT32_MAX when there is none).
    int in_stride_segment_axis = INT32_MAX;
    int out_stride_segment_axis = INT32_MAX;
    for (int i = 0; i < in_nc_str.size(); i++) {
      if (nc_shape[i] == 1) {
        continue;
      }
      if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) {
        throw std::runtime_error("[Sort::eval_gpu] Stride too large.");
      }
      in_stride_segment_axis =
          std::min(in_stride_segment_axis, static_cast(in_nc_str[i]));
      out_stride_segment_axis =
          std::min(out_stride_segment_axis, static_cast(out_nc_str[i]));
    }
    compute_encoder.set_bytes(in_stride_segment_axis, 5);
    compute_encoder.set_bytes(out_stride_segment_axis, 6);
  } else {
    compute_encoder.set_bytes(nc_dim, 5);
    if (nc_shape.empty()) {
      // No non-sorted dims: bind placeholder scalars at indices 6-8.
      int shape = 0;
      int64_t stride = 0;
      compute_encoder.set_bytes(shape, 6);
      compute_encoder.set_bytes(stride, 7);
      compute_encoder.set_bytes(stride, 8);
    } else {
      compute_encoder.set_vector_bytes(nc_shape, 6);
      compute_encoder.set_vector_bytes(in_nc_str, 7);
      compute_encoder.set_vector_bytes(out_nc_str, 8);
    }
  }

  MTL::Size group_dims = MTL::Size(bn, 1, 1);
  MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Sort (or argsort) rows that need more than one block of bn * tn elements.
//
// 1. "sort_mbsort_": a (n_blocks, n_rows) grid sorts each block of each row
//    into dev_vals_0 / dev_idxs_0.
// 2. Repeated "partition_mbsort_" + "merge_mbsort_" passes merge sorted runs.
//    Each pass doubles merge_tiles and alternates (ping-pong) between the
//    *_0 and *_1 buffers.
// 3. The last pass's values, or indices when `argsort`, are copied into out.
// All temporaries are (n_rows, size_sorted_axis) and are registered on the
// stream at the end.
void multi_block_sort(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array& out,
    int axis,
    int bn,
    int tn,
    int n_blocks,
    bool argsort) {
  // Prepare shapes
  int n_rows = in.size() / in.shape(axis);

  auto nc_str = in.strides();
  nc_str.erase(nc_str.begin() + axis);

  auto nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);

  int nc_dim = nc_shape.size();

  // No non-sorted dims: substitute a placeholder shape / stride.
  if (nc_dim == 0) {
    nc_shape = {0};
    nc_str = {1};
  }

  int size_sorted_axis = in.shape(axis);
  int stride_sorted_axis = in.strides()[axis];

  // Make temporary copies
  array dev_vals_0({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
  array dev_vals_1({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});

  array dev_idxs_0({n_rows, size_sorted_axis}, uint32, nullptr, {});
  array dev_idxs_1({n_rows, size_sorted_axis}, uint32, nullptr, {});

  array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});

  // Do allocations
  dev_vals_0.set_data(allocator::malloc(dev_vals_0.nbytes()));
  dev_vals_1.set_data(allocator::malloc(dev_vals_1.nbytes()));
  dev_idxs_0.set_data(allocator::malloc(dev_idxs_0.nbytes()));
  dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
  block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));

  std::vector copies = {
      dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};

  // Prepare command encoder
  auto& compute_encoder = d.get_command_encoder(s.index);

  // Do blockwise sort
  {
    std::ostringstream kname;
    kname << "sort_mbsort_" << type_to_name(dev_vals_0) << "_"
          << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
    auto kernel =
        get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
    compute_encoder.set_compute_pipeline_state(kernel);

    compute_encoder.set_input_array(in, 0);
    compute_encoder.set_output_array(dev_vals_0, 1);
    compute_encoder.set_output_array(dev_idxs_0, 2);
    compute_encoder.set_bytes(size_sorted_axis, 3);
    compute_encoder.set_bytes(stride_sorted_axis, 4);
    compute_encoder.set_bytes(nc_dim, 5);
    compute_encoder.set_vector_bytes(nc_shape, 6);
    compute_encoder.set_vector_bytes(nc_str, 7);

    MTL::Size group_dims = MTL::Size(bn, 1, 1);
    MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
  }

  // Do merges
  bool ping = false;
  array dev_vals_in = dev_vals_0;
  array dev_idxs_in = dev_idxs_0;
  array dev_vals_out = dev_vals_1;
  array dev_idxs_out = dev_idxs_1;

  // One partition thread per block boundary (n_blocks + 1), capped at 1024.
  int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;

  for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
    dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
    dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
    dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
    dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
    ping = !ping;

    // Do partition
    {
      std::ostringstream kname;
      kname << "partition_mbsort_" << type_to_name(dev_vals_in) << "_"
            << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

      auto kernel =
          get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
      compute_encoder.set_compute_pipeline_state(kernel);

      compute_encoder.set_output_array(block_partitions, 0);
      compute_encoder.set_input_array(dev_vals_in, 1);
      compute_encoder.set_input_array(dev_idxs_in, 2);
      compute_encoder.set_bytes(size_sorted_axis, 3);
      compute_encoder.set_bytes(merge_tiles, 4);
      compute_encoder.set_bytes(n_blocks, 5);

      MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
      MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

      compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
    }

    // Do merge
    {
      std::ostringstream kname;
      kname << "merge_mbsort_" << type_to_name(dev_vals_in) << "_"
            << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

      auto kernel =
          get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
      compute_encoder.set_compute_pipeline_state(kernel);

      compute_encoder.set_input_array(block_partitions, 0);
      compute_encoder.set_input_array(dev_vals_in, 1);
      compute_encoder.set_input_array(dev_idxs_in, 2);
      compute_encoder.set_output_array(dev_vals_out, 3);
      compute_encoder.set_output_array(dev_idxs_out, 4);
      compute_encoder.set_bytes(size_sorted_axis, 5);
      compute_encoder.set_bytes(merge_tiles, 6);
      compute_encoder.set_bytes(n_blocks, 7);

      MTL::Size group_dims = MTL::Size(bn, 1, 1);
      MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

      compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
    }
  }

  // Copy outputs with appropriate strides
  // The temporaries hold out's elements with the sorted axis moved last.
  // These source strides read them back in out's index order. This assumes
  // out is row-contiguous; the callers below allocate it fresh.
  auto strides = out.strides();
  for (int ax = axis + 1; ax < strides.size(); ax++) {
    strides[ax] *= out.shape(axis);
  }
  strides[axis] = 1;
  copy_gpu_inplace(
      (argsort) ? dev_idxs_out : dev_vals_out,
      out,
      out.shape(),
      strides,
      out.strides(),
      0,
      0,
      (axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
      s);

  d.add_temporaries(std::move(copies), s.index);
}

// Sort or argsort `in` along `axis_` (negative values count from the end)
// into the already-allocated `out`.
//
// With tn = 4 elements per thread (presumably), bn is the smallest of
// 32/64/128/256/512 that covers the axis, capped at 512. It is reduced from
// 512 to 256 for dtypes wider than 4 bytes. Rows needing more than one
// bn * tn block take the multi-block path.
void gpu_merge_sort(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array& out,
    int axis_,
    bool argsort) {
  // Get size info
  int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
  int size_sorted_axis = in.shape(axis);

  // Get kernel size
  int tn = 4;
  int potential_bn = (size_sorted_axis + tn - 1) / tn;

  int bn;
  if (potential_bn > 256) {
    bn = 512;
  } else if (potential_bn > 128) {
    bn = 256;
  } else if (potential_bn > 64) {
    bn = 128;
  } else if (potential_bn > 32) {
    bn = 64;
  } else {
    bn = 32;
  }

  if (bn == 512 && size_of(in.dtype()) > 4) {
    bn = 256;
  }

  int n_per_block = bn * tn;
  int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;

  if (n_blocks > 1) {
    return multi_block_sort(s, d, in, out, axis, bn, tn, n_blocks, argsort);
  } else {
    return single_block_sort(s, d, in, out, axis, bn, tn, argsort);
  }
}

} // namespace

// Allocate `out` and write the indices that sort the input along axis_.
void ArgSort::eval_gpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort(s, d, in, out, axis_, true);
}

// Allocate `out` and write the input sorted along axis_.
void Sort::eval_gpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort(s, d, in, out, axis_, false);
}

void ArgPartition::eval_gpu(const std::vector& inputs, array& out) {
  // We direct arg partition to sort for now
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort(s, d, in, out, axis_, true);
}

void Partition::eval_gpu(const std::vector& inputs, array& out) {
  // We direct partition to sort for now
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort(s, d, in, out, axis_, false);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/ternary.cpp
================================================
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

// Encode the elementwise ternary kernel `op` over inputs (a, b, c) into
// `out`. Does not allocate `out`; ternary_op_gpu sets its data first.
// Returns immediately when `out` is empty.
//
// Kernel naming:
//   General (strided) case: "g1".."g3" for up to 3 collapsed dims, otherwise
//   "gn" + work_per_thread. A "large" suffix is added when any input
//   data_size or out.size() exceeds INT32_MAX.
//   Other cases: "sv" / "vs" / "v" by TernaryOpType, with "2" when
//   out.data_size() exceeds INT32_MAX, or "n" for several elements per thread.
// The suffix is "_" + op + the dtype name of `b`.
void ternary_op_gpu_inplace(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  assert(inputs.size() == 3);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto& c = inputs[2];
  TernaryOpType topt = get_ternary_op_type(a, b, c);

  if (out.size() == 0) {
    return;
  }

  // Try to collapse contiguous dims
  auto maybe_collapse = [topt, &a, &b, &c, &out]() {
    if (topt == TernaryOpType::General) {
      auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
      return std::make_tuple(
          shape, strides[0], strides[1], strides[2], strides[3]);
    } else {
      Strides e;
      return std::make_tuple(Shape{}, e, e, e, e);
    }
  };
  auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();

  bool large;
  auto ndim = shape.size();
  int work_per_thread;
  if (topt == TernaryOpType::General) {
    large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
        c.data_size() > INT32_MAX || out.size() > INT32_MAX;
    work_per_thread = large ? 4 : 2;
  } else {
    large = out.data_size() > INT32_MAX;
    work_per_thread = get_work_per_thread(b.dtype(), out.data_size());
  }
  std::string kernel_name;
  if (topt == TernaryOpType::General) {
    kernel_name = "g";
    if (shape.size() <= 3) {
      kernel_name += std::to_string(shape.size());
    } else if (work_per_thread > 1) {
      concatenate(kernel_name, "n", std::to_string(work_per_thread));
    }
    if (large) {
      kernel_name += "large";
    }
  } else {
    if (topt == TernaryOpType::VectorScalarVector) {
      kernel_name = "sv";
    } else if (topt == TernaryOpType::VectorVectorScalar) {
      kernel_name = "vs";
    } else {
      kernel_name = "v";
    }
    if (large) {
      kernel_name += "2";
    } else if (work_per_thread > 1) {
      kernel_name += "n";
    }
  }
  concatenate(kernel_name, "_", op, type_to_name(b));

  auto& d = metal::device(s.device);

  auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_input_array(c, 2);
  compute_encoder.set_output_array(out, 3);

  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (topt == TernaryOpType::General) {
    // Launch up to 3D grid of threads
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = out.size() / (dim0 * dim1);

    if (ndim > 3) {
      compute_encoder.set_vector_bytes(shape, 4);
      compute_encoder.set_vector_bytes(strides_a, 5);
      compute_encoder.set_vector_bytes(strides_b, 6);
      compute_encoder.set_vector_bytes(strides_c, 7);

      compute_encoder.set_bytes(ndim, 8);
      // Only the "gn" kernels process work_per_thread elements per thread.
      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder.set_vector_bytes(strides_a, 4);
      compute_encoder.set_vector_bytes(strides_b, 5);
      compute_encoder.set_vector_bytes(strides_c, 6);
    }

    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
    }
    MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  } else {
    // Launch a 1D or 2D grid of threads
    size_t nthreads = ceildiv(out.data_size(), work_per_thread);
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }

    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    MTL::Size grid_dims;
    // NOTE(review): both branches bind the same value at index 4; confirm
    // whether the "2" (large) and regular kernels expect different integer
    // widths for it.
    if (large) {
      compute_encoder.set_bytes(out.data_size(), 4);
      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
    } else {
      compute_encoder.set_bytes(out.data_size(), 4);
      grid_dims = MTL::Size(nthreads, 1, 1);
    }
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  }
}

// Set up `out`'s data according to the op type of (a, b, c), then encode
// `op` on stream `s`.
void ternary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto& c = inputs[2];
  TernaryOpType topt = get_ternary_op_type(a, b, c);
  set_ternary_op_output_data(a, b, c, out, topt);
  ternary_op_gpu_inplace(inputs, out, op, s);
}

// Same as above, on the stream of out's primitive.
void ternary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op) {
  auto& s = out.primitive().stream();
  ternary_op_gpu(inputs, out, op, s);
}

// Select is run as a ternary kernel whose op is the primitive's name().
void Select::eval_gpu(const std::vector& inputs, array& out) {
  ternary_op_gpu(inputs, out, name());
}

} // namespace mlx::core
================================================
FILE: mlx/backend/metal/ternary.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

// Set up `out`'s data for the ternary op, then encode kernel `op` on `s`.
void ternary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s);

// Encode kernel `op` into an `out` whose data is already set up.
void ternary_op_gpu_inplace(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s);

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/unary.cpp
================================================
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/unary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

// Defines func::eval_gpu as a call to unary_op_gpu whose kernel op is the
// primitive's name().
#define UNARY_GPU(func)                                               \
  void func::eval_gpu(const std::vector& inputs, array& out) { \
    unary_op_gpu(inputs, out, name());                               \
  }

namespace mlx::core {

// Encode the unary kernel `op` reading inputs[0] into `out`, whose data must
// already be set up (see unary_op_gpu). Returns immediately for empty input.
//
// Contiguous input: "v", "vn" (several elements per thread) or "v2" (data
// size above UINT32_MAX) over in.data_size() elements.
// Otherwise: "gn" + work_per_thread, with a "large" suffix when sizes exceed
// INT32_MAX, over the collapsed shape/strides on an up-to-3D grid.
// The suffix is "_" + op + input dtype name + output dtype name.
void unary_op_gpu_inplace(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  auto& in = inputs[0];
  bool contig = in.flags().contiguous;
  if (in.size() == 0) {
    return;
  }

  auto& d = metal::device(s.device);

  auto maybe_collapse = [contig, &in]() {
    if (!contig) {
      return collapse_contiguous_dims(in);
    } else {
      return std::make_pair(Shape{}, Strides{});
    }
  };
  auto [shape, strides] = maybe_collapse();
  int ndim = shape.size();
  bool large;
  if (!contig) {
    large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
  } else {
    large = in.data_size() > UINT32_MAX;
  }
  int work_per_thread;
  std::string kernel_name;
  if (contig) {
    work_per_thread = get_work_per_thread(in.dtype(), in.data_size());
    kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v"));
  } else {
    work_per_thread = large ? 4 : 1;
    kernel_name = "gn" + std::to_string(work_per_thread);
    if (large) {
      kernel_name += "large";
    }
  }
  concatenate(kernel_name, "_", op, type_to_name(in), type_to_name(out));
  auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);

  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(out, 1);
  if (!contig) {
    // Launch up to 3D grid of threads
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = out.size() / (dim0 * dim1);
    compute_encoder.set_vector_bytes(shape, 2);
    compute_encoder.set_vector_bytes(strides, 3);
    compute_encoder.set_bytes(ndim, 4);
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::unary] Must use 1024 sized block");
    }
    dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  } else {
    size_t nthreads = ceildiv(in.data_size(), work_per_thread);
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }

    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    MTL::Size grid_dims;
    // NOTE(review): both branches bind the same value at index 2; confirm
    // whether the "v2" and regular kernels expect different integer widths.
    if (large) {
      compute_encoder.set_bytes(in.data_size(), 2);
      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
    } else {
      compute_encoder.set_bytes(in.data_size(), 2);
      grid_dims = MTL::Size(nthreads, 1, 1);
    }
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  }
}

// Set up `out`'s data from inputs[0], then encode kernel `op` on `s`.
void unary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  set_unary_output_data(inputs[0], out);
  unary_op_gpu_inplace(inputs, out, op, s);
}

// Same as above, on the stream of out's primitive.
void unary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op) {
  auto& s = out.primitive().stream();
  unary_op_gpu(inputs, out, op, s);
}

UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Floor)
UNARY_GPU(Ceil)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Sqrt)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)

// Written out explicitly rather than via UNARY_GPU; same forwarding.
void Log::eval_gpu(const std::vector& inputs, array& out) {
  unary_op_gpu(inputs, out, name());
}

// Rounding runs only for inexact dtypes; any other dtype shares the input
// buffer unchanged.
void Round::eval_gpu(const std::vector& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (issubdtype(in.dtype(), inexact)) {
    unary_op_gpu(inputs, out, name());
  } else {
    // No-op integer types
    out.copy_shared_buffer(in);
  }
}

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/unary.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

// Set up `out`'s data from the input, then encode kernel `op` on `s`.
void unary_op_gpu(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s);

// Encode kernel `op` into an `out` whose data is already set up.
void unary_op_gpu_inplace(
    const std::vector& inputs,
    array& out,
    const char* op,
    const Stream& s);

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/utils.cpp
================================================
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/utils.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

// Maps a Dtype to the type name used when building Metal kernel names.
// float64 is spelled "double"; a dtype without a case yields "".
std::string type_to_name(const Dtype& t) {
  switch (t) {
    case bool_:
      return "bool_";
    case uint8:
      return "uint8";
    case uint16:
      return "uint16";
    case uint32:
      return "uint32";
    case uint64:
      return "uint64";
    case int8:
      return "int8";
    case int16:
      return "int16";
    case int32:
      return "int32";
    case int64:
      return "int64";
    case float16:
      return "float16";
    case float32:
      return "float32";
    case float64:
      return "double";
    case bfloat16:
      return "bfloat16";
    case complex64:
      return "complex64";
  }
  return std::string();
}

// Convenience overload: the kernel type name of an array's dtype.
std::string type_to_name(const array& a) {
  return type_to_name(a.dtype());
}

// Thin MTL::Size adapters over the backend-independent helpers declared in
// mlx/backend/common/utils.h.
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2) {
  auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2, pow2);
  return MTL::Size(bx, by, bz);
}

MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) {
  auto [gx, gy, gz] = get_2d_grid_dims_common(shape, strides);
  return MTL::Size(gx, gy, gz);
}

MTL::Size
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) {
  auto [gx, gy, gz] = get_2d_grid_dims_common(shape, strides, divisor);
  return MTL::Size(gx, gy, gz);
}

} // namespace mlx::core

================================================
FILE: mlx/backend/metal/utils.h
================================================
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <type_traits>

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"

namespace mlx::core {

MLX_API std::string type_to_name(const Dtype& t);
MLX_API std::string type_to_name(const array& a);

// Compute the grid and block dimensions, check backend/common/utils.h for docs.
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides); MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor); inline NS::String* make_string(std::ostringstream& os) { std::string string = os.str(); return NS::String::string(string.c_str(), NS::UTF8StringEncoding); } inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) { #ifdef MLX_METAL_DEBUG std::ostringstream label; label << "Stream " << index; queue->setLabel(make_string(label)); #endif } inline void debug_set_primitive_buffer_label( MTL::CommandBuffer* command_buffer, Primitive& primitive) { #ifdef MLX_METAL_DEBUG std::ostringstream label; if (auto cbuf_label = command_buffer->label(); cbuf_label) { label << cbuf_label->utf8String(); } label << primitive.name(); command_buffer->setLabel(make_string(label)); #endif } template constexpr bool is_numeric_except_char = std::is_arithmetic_v && !std::is_same_v && !std::is_same_v && !std::is_same_v && !std::is_same_v; template void concatenate(std::string& acc, T first) { if constexpr (is_numeric_except_char) { acc += std::to_string(first); } else { acc += first; } } template void concatenate(std::string& acc, T first, Args... args) { if constexpr (is_numeric_except_char) { acc += std::to_string(first); } else { acc += first; } concatenate(acc, args...); } inline int get_work_per_thread(Dtype dtype) { return std::max(1, 8 / dtype.size()); } inline int get_work_per_thread(Dtype dtype, size_t size) { constexpr size_t wpt_threshold = 1 << 16; return size < wpt_threshold ? 
1 : std::max(1, 8 / dtype.size()); } inline size_t ceildiv(size_t n, size_t m) { return (n + m - 1) / m; } inline void check_kernel_threadgroup_size( const MTL::ComputePipelineState* kernel, MTL::Size group_dims, const std::string& name) { auto max_size = kernel->maxTotalThreadsPerThreadgroup(); auto requested_size = group_dims.width * group_dims.height * group_dims.depth; if (max_size < requested_size) { std::ostringstream msg; msg << "Maximum threads per threadgroup is " << max_size << " but requested " << requested_size << " for kernel " << name << "."; throw std::runtime_error(msg.str()); } } } // namespace mlx::core ================================================ FILE: mlx/backend/no_cpu/CMakeLists.txt ================================================ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) ================================================ FILE: mlx/backend/no_cpu/compiled.cpp ================================================ // Copyright © 2023-2024 Apple Inc. #include "mlx/compile_impl.h" #include "mlx/primitives.h" namespace mlx::core { // GPU compile is always available if the GPU is available and since we are in // this file CPU compile is not available so check if the device is a GPU // device. namespace detail { bool compile_available_for_device(const Device& device) { return device == Device::gpu; } } // namespace detail void Compiled::eval_cpu( const std::vector& inputs, std::vector& outputs) { throw std::runtime_error( "[Compiled::eval_cpu] CPU compilation not supported on the platform."); } } // namespace mlx::core ================================================ FILE: mlx/backend/no_cpu/device_info.cpp ================================================ // Copyright © 2026 Apple Inc. 
#include "mlx/backend/cpu/device_info.h"

namespace mlx::core::cpu {

// Without a CPU backend no CPU device is ever reported.
bool is_available() {
  return false;
}

int device_count() {
  return 0;
}

// Every index maps to the same, permanently empty, property table.
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info(int /* device_index */) {
  static std::unordered_map<std::string, std::variant<std::string, size_t>>
      empty;
  return empty;
}

} // namespace mlx::core::cpu

================================================
FILE: mlx/backend/no_cpu/primitives.cpp
================================================
// Copyright © 2024 Apple Inc.

#include "mlx/primitives.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fast_primitives.h"

// Stub for a multi-output primitive: eval_cpu always throws.
#define NO_CPU_MULTI(func)                                         \
  void func::eval_cpu(                                             \
      const std::vector<array>& /* inputs */,                      \
      std::vector<array>& /* outputs */) {                         \
    throw std::runtime_error(#func " has no CPU implementation."); \
  }

// Stub for a single-output primitive: eval_cpu always throws.
#define NO_CPU(func)                                               \
  void func::eval_cpu(                                             \
      const std::vector<array>& /* inputs */, array& /* out */) {  \
    throw std::runtime_error(#func " has no CPU implementation."); \
  }

namespace mlx::core {

NO_CPU(Abs)
NO_CPU(Add)
NO_CPU(AddMM)
NO_CPU(Arange)
NO_CPU(ArcCos)
NO_CPU(ArcCosh)
NO_CPU(ArcSin)
NO_CPU(ArcSinh)
NO_CPU(ArcTan)
NO_CPU(ArcTan2)
NO_CPU(ArcTanh)
NO_CPU(ArgPartition)
NO_CPU(ArgReduce)
NO_CPU(ArgSort)
NO_CPU(AsType)
NO_CPU(AsStrided)
NO_CPU(BitwiseBinary)
NO_CPU(BitwiseInvert)
NO_CPU(BlockMaskedMM)
NO_CPU(Broadcast)
NO_CPU(BroadcastAxes)
NO_CPU(Ceil)
NO_CPU(Cholesky)
NO_CPU(Concatenate)
NO_CPU(Conjugate)
NO_CPU(Contiguous)
NO_CPU(Convolution)
NO_CPU(Copy)
NO_CPU(Cos)
NO_CPU(Cosh)
NO_CPU_MULTI(CustomTransforms)
NO_CPU_MULTI(Depends)
NO_CPU(Divide)
NO_CPU_MULTI(DivMod)
NO_CPU(DynamicSlice)
NO_CPU(DynamicSliceUpdate)
NO_CPU(NumberOfElements)
NO_CPU(Remainder)
NO_CPU_MULTI(Eig)
NO_CPU_MULTI(Eigh)
NO_CPU(Equal)
NO_CPU(Erf)
NO_CPU(ErfInv)
NO_CPU(Exp)
NO_CPU(ExpandDims)
NO_CPU(Expm1)
NO_CPU(FFT)
NO_CPU(Flatten)
NO_CPU(Floor)
NO_CPU(Full)
NO_CPU(Gather)
NO_CPU(GatherAxis)
NO_CPU(GatherMM)
NO_CPU(GatherQMM)
NO_CPU(Greater)
NO_CPU(GreaterEqual)
NO_CPU(Hadamard)
NO_CPU(Imag)
NO_CPU(Less)
NO_CPU(LessEqual)
NO_CPU(Log)
NO_CPU(Log1p)
NO_CPU(LogicalNot)
NO_CPU(LogicalAnd)
NO_CPU(LogicalOr)
NO_CPU(LogAddExp)
NO_CPU(LogSumExp)
NO_CPU_MULTI(LUF)
NO_CPU(Matmul)
NO_CPU(Maximum)
NO_CPU(MaskedScatter)
NO_CPU(Minimum)
NO_CPU(Multiply)
NO_CPU(Negative)
NO_CPU(NotEqual)
NO_CPU(Pad)
NO_CPU(Partition)
NO_CPU(Power)
NO_CPU_MULTI(QRF)
NO_CPU(QuantizedMatmul)
NO_CPU(QQMatmul)
NO_CPU(RandomBits)
NO_CPU(Real)
NO_CPU(Reduce)
NO_CPU(Reshape)
NO_CPU(Round)
NO_CPU(Scan)
NO_CPU(Scatter)
NO_CPU(ScatterAxis)
NO_CPU(Select)
NO_CPU(SegmentedMM)
NO_CPU(Sigmoid)
NO_CPU(Sign)
NO_CPU(Sin)
NO_CPU(Sinh)
NO_CPU(Slice)
NO_CPU(SliceUpdate)
NO_CPU(Softmax)
NO_CPU(Sort)
NO_CPU_MULTI(Split)
NO_CPU(Square)
NO_CPU(Squeeze)
NO_CPU(Sqrt)
NO_CPU(StopGradient)
NO_CPU(Subtract)
NO_CPU_MULTI(SVD)
NO_CPU(Tan)
NO_CPU(Tanh)
NO_CPU(Transpose)
NO_CPU(Unflatten)
NO_CPU(Inverse)
NO_CPU(View)

namespace fast {
NO_CPU_MULTI(Quantize)
NO_CPU_MULTI(ConvertFP8)
} // namespace fast

namespace distributed {
NO_CPU_MULTI(AllReduce)
NO_CPU_MULTI(AllGather)
NO_CPU_MULTI(Send)
NO_CPU_MULTI(Recv)
NO_CPU_MULTI(ReduceScatter)
} // namespace distributed

} // namespace mlx::core

================================================
FILE: mlx/backend/no_gpu/CMakeLists.txt
================================================
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)

================================================
FILE: mlx/backend/no_gpu/allocator.cpp
================================================
// Copyright © 2023 Apple Inc.
#include #include #include "mlx/allocator.h" #include "mlx/memory.h" #ifdef __APPLE__ #include "mlx/backend/no_gpu/apple_memory.h" #elif defined(__linux__) #include "mlx/backend/no_gpu/linux_memory.h" #else size_t get_memory_size() { return 0; } #endif namespace mlx::core { namespace allocator { class CommonAllocator : public Allocator { /** A general CPU allocator. */ public: virtual Buffer malloc(size_t size) override; virtual void free(Buffer buffer) override; virtual size_t size(Buffer buffer) const override; size_t get_active_memory() const { return active_memory_; }; size_t get_peak_memory() const { return peak_memory_; }; void reset_peak_memory() { std::unique_lock lk(mutex_); peak_memory_ = 0; }; size_t get_memory_limit() { return memory_limit_; } size_t set_memory_limit(size_t limit) { std::unique_lock lk(mutex_); std::swap(memory_limit_, limit); return limit; } private: size_t memory_limit_; size_t active_memory_{0}; size_t peak_memory_{0}; std::mutex mutex_; CommonAllocator() : memory_limit_(0.8 * get_memory_size()) { if (memory_limit_ == 0) { memory_limit_ = 1UL << 33; } }; friend CommonAllocator& common_allocator(); }; CommonAllocator& common_allocator() { static CommonAllocator allocator_; return allocator_; } Allocator& allocator() { return common_allocator(); } void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } return static_cast(ptr_) + 1; } Buffer CommonAllocator::malloc(size_t size) { void* ptr = std::malloc(size + sizeof(size_t)); if (ptr != nullptr) { *static_cast(ptr) = size; } std::unique_lock lk(mutex_); active_memory_ += size; peak_memory_ = std::max(active_memory_, peak_memory_); return Buffer{ptr}; } void CommonAllocator::free(Buffer buffer) { auto sz = size(buffer); std::free(buffer.ptr()); std::unique_lock lk(mutex_); active_memory_ -= sz; } size_t CommonAllocator::size(Buffer buffer) const { if (buffer.ptr() == nullptr) { return 0; } return *static_cast(buffer.ptr()); } } // namespace allocator size_t get_active_memory() { return 
allocator::common_allocator().get_active_memory(); } size_t get_peak_memory() { return allocator::common_allocator().get_peak_memory(); } void reset_peak_memory() { return allocator::common_allocator().reset_peak_memory(); } size_t set_memory_limit(size_t limit) { return allocator::common_allocator().set_memory_limit(limit); } size_t get_memory_limit() { return allocator::common_allocator().get_memory_limit(); } // No-ops for common allocator size_t get_cache_memory() { return 0; } size_t set_cache_limit(size_t) { return 0; } size_t set_wired_limit(size_t) { return 0; } void clear_cache() {} } // namespace mlx::core ================================================ FILE: mlx/backend/no_gpu/apple_memory.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include namespace { size_t get_memory_size() { size_t memsize = 0; size_t length = sizeof(memsize); sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); return memsize; } } // namespace ================================================ FILE: mlx/backend/no_gpu/device_info.cpp ================================================ // Copyright © 2026 Apple Inc. #include "mlx/backend/gpu/device_info.h" namespace mlx::core::gpu { bool is_available() { return false; } int device_count() { return 0; } const std::unordered_map>& device_info(int /* device_index */) { static std::unordered_map> empty; return empty; } } // namespace mlx::core::gpu ================================================ FILE: mlx/backend/no_gpu/eval.cpp ================================================ // Copyright © 2025 Apple Inc. 
#include #include "mlx/backend/gpu/device_info.h" #include "mlx/backend/gpu/eval.h" namespace mlx::core::gpu { void new_stream(Stream) {} void eval(array&) { throw std::runtime_error("[gpu::eval] GPU backend is not available"); } void finalize(Stream) { throw std::runtime_error("[gpu::finalize] GPU backend is not available"); } void synchronize(Stream) { throw std::runtime_error("[gpu::synchronize] GPU backend is not available"); } } // namespace mlx::core::gpu ================================================ FILE: mlx/backend/no_gpu/event.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/event.h" #include "mlx/scheduler.h" #include #include namespace mlx::core { struct EventCounter { uint64_t value{0}; std::mutex mtx; std::condition_variable cv; }; Event::Event(Stream stream) : stream_(stream) { auto dtor = [](void* ptr) { delete static_cast(ptr); }; event_ = std::shared_ptr(new EventCounter{}, dtor); } void Event::wait() { auto ec = static_cast(event_.get()); std::unique_lock lk(ec->mtx); if (ec->value >= value()) { return; } ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; }); } void Event::wait(Stream stream) { scheduler::enqueue(stream, [*this]() mutable { wait(); }); } void Event::signal(Stream stream) { scheduler::enqueue(stream, [*this]() mutable { auto ec = static_cast(event_.get()); { std::lock_guard lk(ec->mtx); ec->value = value(); } ec->cv.notify_all(); }); } bool Event::is_signaled() const { auto ec = static_cast(event_.get()); { std::lock_guard lk(ec->mtx); return (ec->value >= value()); } } } // namespace mlx::core ================================================ FILE: mlx/backend/no_gpu/fence.cpp ================================================ // Copyright © 2024 Apple Inc. 
#include #include #include "mlx/fence.h" #include "mlx/scheduler.h" namespace mlx::core { struct FenceImpl { uint32_t count{0}; uint32_t value{0}; std::mutex mtx; std::condition_variable cv; }; Fence::Fence(Stream) { auto dtor = [](void* ptr) { delete static_cast(ptr); }; fence_ = std::shared_ptr(new FenceImpl{}, dtor); } void Fence::wait(Stream stream, const array&) { auto& f = *static_cast(fence_.get()); if (stream.device == Device::cpu) { scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable { auto& f = *static_cast(fence_.get()); std::unique_lock lk(f.mtx); if (f.value >= count) { return; } f.cv.wait(lk, [&f, count] { return f.value >= count; }); }); } else { throw std::runtime_error("[Fence::wait] Invalid stream."); } } void Fence::update(Stream stream, const array&, bool) { auto& f = *static_cast(fence_.get()); f.count++; if (stream.device == Device::cpu) { scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable { auto& f = *static_cast(fence_.get()); std::unique_lock lk(f.mtx); f.value = count; f.cv.notify_all(); }); } else { throw std::runtime_error("[Fence::update] Invalid stream."); } } } // namespace mlx::core ================================================ FILE: mlx/backend/no_gpu/linux_memory.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include namespace { size_t get_memory_size() { struct sysinfo info; if (sysinfo(&info) != 0) { return 0; } size_t total_ram = info.totalram; total_ram *= info.mem_unit; return total_ram; } } // namespace ================================================ FILE: mlx/backend/no_gpu/primitives.cpp ================================================ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/primitives.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fast_primitives.h"

// Stub for a multi-output primitive: eval_gpu always throws.
#define NO_GPU_MULTI(func)                                         \
  void func::eval_gpu(                                             \
      const std::vector<array>& /* inputs */,                      \
      std::vector<array>& /* outputs */) {                         \
    throw std::runtime_error(#func " has no GPU implementation."); \
  }

// Stub for a fast:: primitive: always request the fallback path, and throw
// if eval_gpu is reached anyway.
#define NO_GPU_USE_FALLBACK(func)            \
  bool func::use_fallback(Stream /* s */) { \
    return true;                             \
  }                                          \
  NO_GPU_MULTI(func)

// Stub for a single-output primitive: eval_gpu always throws.
#define NO_GPU(func)                                               \
  void func::eval_gpu(                                             \
      const std::vector<array>& /* inputs */, array& /* out */) {  \
    throw std::runtime_error(#func " has no GPU implementation."); \
  }

namespace mlx::core {

// Attention always takes the fallback path without a GPU backend.
bool fast::ScaledDotProductAttention::use_fallback(
    const array& /* q */,
    const array& /* k */,
    const array& /* v */,
    bool /* has_mask */,
    bool /* has_arr_mask */,
    bool /* do_causal */,
    bool /* is_training */,
    bool /* output_logsumexp */,
    Stream /* s */) {
  return true;
}

bool fast::ScaledDotProductAttention::supports_bool_mask() {
  return false;
}

bool fast::ScaledDotProductAttentionVJP::use_fallback(
    const array& /* q */,
    Stream /* s */) {
  return true;
}

NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
NO_GPU(Arange)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTan2)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(BitwiseInvert)
NO_GPU(BlockMaskedMM)
NO_GPU(Broadcast)
NO_GPU(BroadcastAxes)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate)
NO_GPU(Conjugate)
NO_GPU(Contiguous)
NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomTransforms)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(NumberOfElements)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(ExpandDims)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Flatten)
NO_GPU(Floor)
NO_GPU(Full)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Imag)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Pad)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(QQMatmul)
NO_GPU(RandomBits)
NO_GPU(Real)
NO_GPU(Reduce)
NO_GPU(Reshape)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(SegmentedMM)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(Slice)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU_MULTI(Split)
NO_GPU(Square)
NO_GPU(Squeeze)
NO_GPU(Sqrt)
NO_GPU(StopGradient)
NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Unflatten)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)
NO_GPU_MULTI(Eig)
NO_GPU(View)
NO_GPU(MaskedScatter)

namespace fast {
NO_GPU_USE_FALLBACK(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU_MULTI(ScaledDotProductAttention)
NO_GPU_MULTI(ScaledDotProductAttentionVJP)
NO_GPU_MULTI(ConvertFP8)
NO_GPU_MULTI(Quantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast

namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
NO_GPU_MULTI(ReduceScatter)
} // namespace distributed

} // namespace mlx::core

================================================
FILE: mlx/compile.cpp
================================================
// Copyright © 2023-2024 Apple Inc.
#include #include #include #include #include #include #include "mlx/allocator.h" #include "mlx/backend/common/compiled.h" #include "mlx/compile.h" #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" #include "mlx/graph_utils.h" #include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" #include "mlx/utils.h" namespace mlx::core { constexpr int max_compile_depth = 11; constexpr int max_compile_arrays = 24; bool is_unary(const Primitive& p) { return ( typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) || typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) || typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) || typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) || typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) || typeid(p) == typeid(Conjugate) || typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) || typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) || typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) || typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) || typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) || typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) || typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) || typeid(p) == typeid(Imag) || typeid(p) == typeid(BitwiseInvert)); } bool is_binary(const Primitive& p) { return ( typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) || typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) || typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) || typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) || typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) || typeid(p) == typeid(Minimum) 
|| typeid(p) == typeid(Multiply) || typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) || typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary) || typeid(p) == typeid(ArcTan2)); } bool is_ternary(const Primitive& p) { return typeid(p) == typeid(Select); } bool is_broadcast(const Primitive& p) { return typeid(p) == typeid(Broadcast); } bool is_noop(const Primitive& p) { return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient); } bool is_reduction(const Primitive& p) { return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce); } bool is_fusable(const Primitive& p) { return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p); } Compiled::Compiled( Stream stream, std::vector inputs, std::vector outputs, std::vector tape, std::unordered_set constant_ids) : Primitive(stream), inputs_(std::move(inputs)), outputs_(std::move(outputs)), tape_(std::move(tape)), constant_ids_(std::move(constant_ids)), is_constant_([this](size_t i) { return constant_ids_.find(inputs_[i].id()) != constant_ids_.end(); }) { // Build the kernel name. NodeNamer namer; std::ostringstream os; std::ostringstream constant_hasher; std::unordered_set output_ids; for (auto& o : outputs_) { output_ids.insert(o.id()); } // Fill the input names. This is not really necessary, I just like having A, // B, C, ... as the inputs. for (const auto& x : inputs_) { namer.get_name(x); } // The primitives describing the tape. For unary and binary primitives this // must be enough to describe the full computation. 
for (const auto& a : tape_) { // name and type of output os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); // whether or not it's an output if (output_ids.find(a.id()) != output_ids.end()) { os << "O"; } else { os << "I"; } // computation performed os << a.primitive().name(); // name of inputs to the function for (auto& inp : a.inputs()) { os << namer.get_name(inp); } } os << "_"; for (const auto& x : inputs_) { if (constant_ids_.find(x.id()) != constant_ids_.end()) { os << "C"; print_constant(constant_hasher, x); } else { os << (is_scalar(x) ? "S" : "V"); } } os << "_"; for (const auto& x : inputs) { if (constant_ids.find(x.id()) != constant_ids.end()) { continue; } os << kindof(x.dtype()) << x.itemsize(); } os << "_" << std::hash{}(constant_hasher.str()); kernel_lib_ = os.str(); } std::vector Compiled::vjp( const std::vector&, const std::vector&, const std::vector&, const std::vector&) { throw std::runtime_error("[Compiled] Cannot vjp primitive."); } std::vector Compiled::jvp( const std::vector&, const std::vector&, const std::vector&) { throw std::runtime_error("[Compiled] Cannot jvp primitive."); } std::pair, std::vector> Compiled::vmap( const std::vector&, const std::vector&) { throw std::runtime_error("[Compiled] Cannot vmap primitive."); } bool Compiled::is_equivalent(const Primitive& other) const { const Compiled& a_other = static_cast(other); return std::equal( tape_.begin(), tape_.end(), a_other.tape_.begin(), a_other.tape_.end(), [](const array& a1, const array& a2) { auto& p1 = a1.primitive(); auto& p2 = a2.primitive(); return typeid(p1) == typeid(p2) && p1.is_equivalent(p2); }); } const char* Compiled::name() const { if (name_.empty()) { std::ostringstream os; os << "Compiled"; for (auto& a : tape_) { os << a.primitive().name(); } name_ = os.str(); } return name_.c_str(); } std::vector Compiled::output_shapes(const std::vector& inputs) { size_t nd = 0; for (auto& in : inputs) { nd = std::max(nd, in.ndim()); } Shape out_shape(nd, 0); for 
(auto& in : inputs) { auto dd = nd - in.ndim(); for (auto i = dd; i < nd; ++i) { out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]); } } // All outputs have the same shape return std::vector(outputs_.size(), out_shape); } namespace detail { std::atomic& compile_mode() { auto get_val = []() { if (std::getenv("MLX_DISABLE_COMPILE")) { return CompileMode::disabled; } else { return CompileMode::enabled; } }; static std::atomic compile_mode_ = get_val(); return compile_mode_; } // Helper like below but only merges the two provided arrays. If the src has // siblings then these won't be merged to the dst. void merge_one(array& dst, array& src, ParentsMap& parents_map) { auto src_parents = parents_map.find(src.id()); if (src_parents == parents_map.end()) { return; } auto& pairs = parents_map[dst.id()]; for (auto& parent : src_parents->second) { parent.first.inputs()[parent.second] = dst; pairs.push_back(parent); } // If src is a parent of dst, remove it from dst's parents for (auto it = pairs.begin(); it != pairs.end();) { if (it->first.id() == src.id()) { it = pairs.erase(it); } else { it++; } } // Remove the source from the map to avoid fusing with it again parents_map.erase(src_parents); } // Helper that merges two arrays in the graph by setting the parents of the // source to point to the destination. The arrays are assumed to be coming from // equivalent primitives so their siblings are merged as well. void merge(array& dst, array& src, ParentsMap& parents_map) { // Canonicalize the order of the primitives outputs auto sources = src.outputs(); auto dests = dst.outputs(); // For each src parent, point it to the corresponding dst for (int i = 0; i < sources.size(); ++i) { merge_one(dests[i], sources[i], parents_map); } } // Any parent in the divider will continue to refer to `x` but any parent not // in the divider will refer to a copy of the operation. 
array split_one( const array& x, ParentsMap& parents_map, const std::unordered_set& divider) { array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs()); auto& x_parents = parents_map[x.id()]; auto& y_parents = parents_map[y.id()]; for (auto it = x_parents.begin(); it != x_parents.end();) { if (divider.find(it->first.id()) != divider.end()) { it->first.inputs()[it->second] = y; y_parents.emplace_back(std::move(*it)); it = x_parents.erase(it); } else { it++; } } return y; } template std::uintptr_t get_function_address(const std::function& fun) { using FunType = T (*)(U...); const FunType* fun_ptr = fun.template target(); if (fun_ptr == nullptr) { return 0; } return reinterpret_cast(*fun_ptr); } class CompilerCache { public: struct CacheEntry { CacheEntry(Stream stream, bool shapeless) : stream(stream), shapeless(shapeless) {}; Stream stream; bool shapeless; std::vector inputs; std::vector outputs; std::vector tape; bool empty{true}; std::vector constants; std::shared_ptr extra; }; // Returns a reference to a CacheEntry which can be updated // by the caller to avoid copying large tapes / inputs / outputs CacheEntry& find( std::uintptr_t fun_id, const std::vector& inputs, bool shapeless, const std::vector& constants) { // Find the cache entries for |fun_id|. std::vector& entries = cache_[fun_id]; // Compare if 2 arrays have same shape and dtype. auto has_same_shape_and_dtype = [shapeless]( const std::vector& in1, const std::vector& in2) { if (in1.size() != in2.size()) { return false; } for (size_t i = 0; i < in1.size(); ++i) { if (in1[i].ndim() != in2[i].ndim()) { return false; } if (!shapeless && in1[i].shape() != in2[i].shape()) { return false; } if (in1[i].dtype() != in2[i].dtype()) { return false; } } return true; }; // Loop over entries and check: // - Default stream and device match the entry's default stream // - Inputs match i.e. shapes and types must be equal. 
auto stream = default_stream(default_device()); for (CacheEntry& entry : entries) { // Check that the default stream and device match if (entry.stream != stream) { continue; } if (entry.shapeless != shapeless) { continue; } // Check the inputs match and return if so if (has_same_shape_and_dtype(inputs, entry.inputs) && constants == entry.constants) { return entry; } } // Otherwise append a new cache entry entries.push_back(CacheEntry{stream, shapeless}); return entries.back(); } void erase(std::uintptr_t fun_id) { cache_.erase(fun_id); } void clear() { cache_.clear(); } private: CompilerCache() { // Make sure the allocator is fully // initialized before the compiler cache allocator::allocator(); } friend CompilerCache& compiler_cache(); std::unordered_map> cache_; }; CompilerCache& compiler_cache() { static thread_local CompilerCache compiler_cache_; return compiler_cache_; } std::tuple, std::vector, std::shared_ptr> compile_trace( const ArrayFnWithExtra& fun, const std::vector& inputs, bool shapeless) { // Set the global tracing flag. 
detail::InTracing in_tracing{shapeless}; // Run the function on placeholder inputs // to get compute graph std::vector tracer_inputs; for (int i = 0; i < inputs.size(); ++i) { array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {}); in.set_tracer(true); tracer_inputs.push_back(std::move(in)); } auto output = fun(tracer_inputs); return {tracer_inputs, output.first, output.second}; } // Traverses the graph to build a tape and a map of array ids to their parents std::pair, ParentsMap> compile_dfs( const std::vector& inputs, std::vector& outputs, const std::vector& original_inputs) { std::vector tape; std::unordered_map>> parents_map; { std::function recurse; std::unordered_set input_set; std::unordered_set original_input_set; for (int i = 0; i < inputs.size(); ++i) { input_set.insert(inputs[i].id()); original_input_set.insert(original_inputs[i].id()); } // DFS the graph to build the tape, and log parents and scalars std::unordered_set cache; recurse = [&](const array& a) { auto id = a.id(); if (original_input_set.find(id) != original_input_set.end()) { throw std::invalid_argument( "[compile] Attempting to compile a function with uncaptured inputs is not allowed."); } if (cache.find(id) != cache.end()) { return; } for (int i = 0; i < a.inputs().size(); i++) { auto& in = a.inputs()[i]; parents_map[in.id()].push_back({a, i}); for (auto& s : a.siblings()) { parents_map[in.id()].push_back({s, i}); } // Don't recurse on inputs (but add them to the tape for the purpose // of future optimizations) if (input_set.find(a.id()) == input_set.end()) { recurse(in); } } cache.insert(id); for (auto& s : a.siblings()) { cache.insert(s.id()); } tape.push_back(a); }; for (auto& a : outputs) { recurse(a); } } // Deep copy the tape and parents map while preserving inputs and outputs std::vector new_tape; std::unordered_set io_set; std::unordered_map old_to_new; for (auto& o : outputs) { old_to_new.insert({o.id(), o}); io_set.insert(o.id()); for (auto& s : o.siblings()) { 
old_to_new.insert({s.id(), s}); io_set.insert(s.id()); } } for (auto& i : inputs) { io_set.insert(i.id()); old_to_new.insert({i.id(), i}); } new_tape.reserve(tape.size()); for (auto& arr : tape) { if (!arr.has_primitive() || (io_set.find(arr.id()) != io_set.end())) { old_to_new.insert({arr.id(), arr}); new_tape.push_back(arr); continue; } std::vector inputs; inputs.reserve(arr.inputs().size()); for (auto& i : arr.inputs()) { inputs.push_back(old_to_new.find(i.id())->second); } if (arr.siblings().size() > 0) { std::vector types; std::vector shapes; auto out = arr.outputs(); for (auto& o : out) { types.push_back(o.dtype()); shapes.push_back(o.shape()); } auto as = array::make_arrays( std::move(shapes), types, arr.primitive_ptr(), std::move(inputs)); for (int i = 0; i < out.size(); ++i) { old_to_new.insert({out[i].id(), as[i]}); } new_tape.push_back(as[arr.sibling_position()]); } else { auto a = array( arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs)); old_to_new.insert({arr.id(), a}); new_tape.push_back(a); } } io_set.clear(); for (auto& o : outputs) { if (!(io_set.insert(o.id()).second)) { continue; } for (auto& i : o.inputs()) { i = old_to_new.find(i.id())->second; } for (auto& s : o.siblings()) { io_set.insert(s.id()); for (auto& i : s.inputs()) { i = old_to_new.find(i.id())->second; } } } tape = std::move(new_tape); std::unordered_map>> new_parents_map; for (auto& [id, vec] : parents_map) { for (auto& [a, _] : vec) { a = old_to_new.find(a.id())->second; } new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec); } parents_map = std::move(new_parents_map); return {tape, parents_map}; } static inline uint64_t splitmix64(uint64_t x) noexcept { x += 0x9e3779b97f4a7c15ull; x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ull; x = (x ^ (x >> 27)) * 0x94d049bb133111ebull; return x ^ (x >> 31); } struct VecU64Hash { size_t operator()(const std::vector& s) const noexcept { uint64_t h = 0x243f6a8885a308d3ull ^ (uint64_t)s.size() * 0x9e3779b97f4a7c15ull; 
for (uint64_t x : s) { h = splitmix64(x ^ splitmix64(h + 0x9e3779b97f4a7c15ull)); } return (size_t)h; } }; // Simplify the tape. Note, this function modifies in-place both the tape, // the parents map to remove orphaned arrays, and potentially the outputs void compile_simplify( std::vector& tape, ParentsMap& parents_map, std::vector& outputs, int passes) { // Helpers to identify identical scalars std::map, array> scalars; auto is_scalar = [](const array& a) { // Condition for when it's safe to read an array return a.is_available() && a.ndim() == 0; }; auto get_scalar_rep = [](const array& a) { uint64_t v = 0; switch (a.dtype().size()) { case 1: v = *a.data(); break; case 2: v = *a.data(); break; case 4: v = *a.data(); break; case 8: v = *a.data(); break; } return std::make_pair(v, a.dtype().val()); }; for (auto& a : tape) { if (is_scalar(a)) { scalars.insert({get_scalar_rep(a), a}); } } // Depth-1 array equivalence check. auto array_equivalent = [](const array& a, const array& b) { if (!a.has_primitive() || !b.has_primitive()) { return false; } if (a.primitive_id() == b.primitive_id()) { return false; } const auto& pa = a.primitive(); const auto& pb = b.primitive(); if (typeid(pa) != typeid(pb)) { return false; } if (a.inputs().size() != b.inputs().size()) { return false; } for (int i = 0; i < a.inputs().size(); i++) { if (a.inputs()[i].id() != b.inputs()[i].id()) { return false; } } return pa.is_equivalent(pb); }; // Merge scalars std::vector new_tape; for (auto& arr : tape) { // Check if we can merge scalars if (is_scalar(arr)) { auto scalar = scalars.find(get_scalar_rep(arr)); if (scalar->second.id() != arr.id()) { merge(scalar->second, arr, parents_map); // Don't keep orphaned scalars in the tape continue; } } new_tape.push_back(std::move(arr)); } tape = std::move(new_tape); // Remove no-ops { std::unordered_map output_map; for (auto& o : outputs) { output_map.insert({o.id(), o}); } for (auto& arr : tape) { if (!arr.has_primitive() || !is_noop(arr.primitive())) 
{ new_tape.push_back(std::move(arr)); continue; } merge_one(arr.inputs()[0], arr, parents_map); if (auto it = output_map.find(arr.id()); it != output_map.end()) { it->second = arr.inputs()[0]; } } tape = std::move(new_tape); for (auto& o : outputs) { o = output_map.at(o.id()); } } std::unordered_map tape_order; for (uint32_t i = 0; i < tape.size(); ++i) { tape_order.insert({tape[i].id(), i}); } std::unordered_set output_set; for (auto& o : outputs) { output_set.insert(o.id()); } // Multi-pass merge only keeping non-orphaned arrays in the tape for (int pass = 0; pass < passes; ++pass) { for (auto& arr : tape) { // Helper to check if we can merge the parents of the // given array auto maybe_merge_parents = [&](auto& a) { auto parents = parents_map.find(a.id()); if (parents != parents_map.end()) { auto N = parents->second.size(); std::vector mask(N, false); auto try_merge = [&](int dst_idx, int src_idx) { if (tape_order[parents->second[src_idx].first.id()] < tape_order[parents->second[dst_idx].first.id()]) { std::swap(src_idx, dst_idx); } auto& src = parents->second[src_idx].first; auto& dst = parents->second[dst_idx].first; if (src.id() != dst.id() && array_equivalent(src, dst) && output_set.find(src.id()) == output_set.end()) { merge(dst, src, parents_map); mask[src_idx] = true; } }; if (N > 100) { std::unordered_map< std::vector, std::vector, VecU64Hash> dst_map; // Find possibly mergeable groups for (int i = 0; i < N; i++) { // Make the hash key std::vector key; auto& curr = parents->second[i].first; key.reserve(curr.inputs().size() + 2); for (auto& in : curr.inputs()) { key.push_back(in.id()); } auto& p = curr.primitive(); key.push_back(curr.inputs().size()); key.push_back(typeid(p).hash_code()); auto it = dst_map.find(key); if (it == dst_map.end()) { bool _; std::tie(it, _) = dst_map.insert({key, std::vector{}}); } it->second.push_back(i); } for (auto& [_, group] : dst_map) { for (int i = 0; i < group.size(); ++i) { if (mask[group[i]]) { continue; } for (int j = 
i + 1; j < group.size(); ++j) { if (mask[group[j]]) { continue; } try_merge(group[i], group[j]); } } } } else { for (int i = 0; i < N; ++i) { if (mask[i]) { continue; } for (int j = i + 1; j < N; ++j) { if (mask[j]) { continue; } try_merge(i, j); } } } // Erase orphaned parents so we don't keep fusing with them for (int i = N - 1; i >= 0; --i) { if (mask[i]) { parents->second.erase(parents->second.begin() + i); } } return false; } else { return output_set.find(a.id()) == output_set.end(); } }; bool discard = maybe_merge_parents(arr); for (auto& s : arr.siblings()) { discard &= maybe_merge_parents(s); } // If an array and its siblings have no parents, and none of them are // outputs, it is safe to remove it from the tape if (!discard) { new_tape.push_back(std::move(arr)); } } tape = std::move(new_tape); } } // Extract sub-graphs of the graph that can be compiled // and replace them with a Compiled Primitive. void compile_fuse( std::vector& tape, ParentsMap& parents_map, const std::vector& inputs, std::vector& outputs) { // Track outputs to replace with new compiled outputs std::unordered_map output_map; for (auto& o : outputs) { output_map.insert({o.id(), o}); } // Set of inputs to distinguish constants std::unordered_set input_ids; for (auto& in : inputs) { input_ids.insert(in.id()); } // Go through the tape in reverse order and check for fusable sub-graphs std::vector new_tape; std::unordered_set global_cache; for (int i = tape.size() - 1; i >= 0; --i) { auto& arr = tape[i]; // Already compiled if (global_cache.find(arr.id()) != global_cache.end()) { continue; } // Two pass recursion: // First pass: // - Collect all the primitives which we can fuse with // - Keeps a cache of fusable primitives which may be added out of // DAG order. We have to determine if all of a fused primitive's // outputs are also in the fused section, and this may not be the // case the first time we visit it. 
// Second pass: // - Collect inputs to the new compiled primitive // - Add fusable primitives to a tape in the correct order std::function recurse; std::unordered_set cache; std::unordered_set input_set; recurse = [&](const array& a, int depth, const Stream& s, const Shape& shape) { if (cache.find(a.id()) != cache.end()) { return; } // Stop fusing if: // - Depth limit exceeded // - Constant input // - Stream mismatch // - Non fusable primitive // - Is global output but has a different shape if (depth >= max_compile_depth || !a.has_primitive() || a.primitive().stream() != s || !is_fusable(a.primitive()) || (output_map.find(a.id()) != output_map.end() && a.shape() != shape)) { // Possible input input_set.insert(a.id()); return; } bool all_parents_in = true; if (depth > 0) { // Guaranteed to have a parent since nested in the // recursion. auto& parents = parents_map.at(a.id()); for (auto& [p, idx] : parents) { auto in_cache = cache.find(p.id()) != cache.end(); if (!in_cache) { all_parents_in = false; break; } } } // Arrays with a mix of parents outside the compilable section // are not fusable except for broadcast which we can split to avoid // stopping fusion if (!all_parents_in) { if (a.has_primitive() && is_broadcast(a.primitive()) && input_set.size() < max_compile_arrays) { array b = split_one(a, parents_map, cache); recurse(b, depth, s, shape); } else { // Possible input input_set.insert(a.id()); } return; } if (output_map.find(a.id()) != output_map.end()) { input_set.insert(a.id()); } else { // Not an input anymore since fusing it input_set.erase(a.id()); } if (input_set.size() >= max_compile_arrays) { return; } cache.insert({a.id()}); for (auto& in : a.inputs()) { recurse(in, depth + 1, s, shape); } }; // This will be the result of the fused operation so it needs // a) to not be already computed ie have a primitive // b) that primitive to not be a broadcast since it will unnecessarily // cast to a contiguous array potentially blowing up memory if 
(arr.has_primitive() && !is_broadcast(arr.primitive())) { Stream s = arr.primitive().stream(); recurse(arr, 0, s, arr.shape()); } // Not worth fusing a single primitive if (cache.size() <= 1) { new_tape.push_back(arr); continue; } // Recurse a second time to build the tape in the right // order and collect the inputs input_set.clear(); std::vector inputs; std::vector fused_tape; std::unordered_set tape_set; std::function recurse_tape; recurse_tape = [&](const array& a) { if (cache.find(a.id()) == cache.end()) { if (input_set.find(a.id()) == input_set.end()) { input_set.insert(a.id()); inputs.push_back(a); } return; } if (tape_set.find(a.id()) != tape_set.end()) { return; } tape_set.insert(a.id()); for (auto& in : a.inputs()) { recurse_tape(in); } fused_tape.push_back(a); }; recurse_tape(arr); std::vector old_outputs; // Add to global cache and add any global outputs to outputs // of new primitive for (int j = 0; j < fused_tape.size() - 1; ++j) { auto& f = fused_tape[j]; if (output_map.find(f.id()) != output_map.end()) { old_outputs.push_back(f); // Parents are now siblings, update the parent map auto& pairs = parents_map[f.id()]; pairs.erase( std::remove_if( pairs.begin(), pairs.end(), [&](auto& p) { return cache.find(p.first.id()) != cache.end(); }), pairs.end()); } else { // Remove inner fused arrays parents from the parents map // to keep the parents map in a valid state parents_map.erase(f.id()); } global_cache.insert({f.id()}); } old_outputs.push_back(arr); std::vector shapes; std::vector types; for (auto& o : old_outputs) { if (o.shape() != old_outputs.back().shape()) { throw std::runtime_error( "[compile] Compilation failed. 
Tried to fuse operations with different output shapes"); } shapes.push_back(o.shape()); types.push_back(o.dtype()); } std::unordered_set constant_ids; for (auto& in : inputs) { // Scalar constant if (in.size() == 1 && !in.has_primitive() && input_ids.find(in.id()) == input_ids.end()) { constant_ids.insert(in.id()); } } auto compiled_outputs = array::make_arrays( std::move(shapes), types, std::make_shared( old_outputs.back().primitive().stream(), inputs, old_outputs, std::move(fused_tape), std::move(constant_ids)), inputs); // One output per primitive new_tape.push_back(compiled_outputs.back()); // Replace inputs old parents with compiled_outputs for (int i = 0; i < inputs.size(); ++i) { auto& pairs = parents_map[inputs[i].id()]; pairs.erase( std::remove_if( pairs.begin(), pairs.end(), [&](auto& p) { return cache.find(p.first.id()) != cache.end(); }), pairs.end()); for (auto& o : compiled_outputs) { pairs.push_back({o, i}); } } // - Update outputs parents to point to compiled outputs // - Update any overall graph outputs to be compiled outputs for (int o = 0; o < old_outputs.size(); ++o) { merge_one(compiled_outputs[o], old_outputs[o], parents_map); if (auto it = output_map.find(old_outputs[o].id()); it != output_map.end()) { it->second = compiled_outputs[o]; } } } std::reverse(new_tape.begin(), new_tape.end()); tape = std::move(new_tape); // Replace output with potentially compiled output for (auto& o : outputs) { o = output_map.at(o.id()); } } std::vector compile_replace( const std::vector& tape, const std::vector& trace_inputs, const std::vector& trace_outputs, const std::vector& inputs, bool shapeless) { std::unordered_map trace_to_real; for (int i = 0; i < inputs.size(); ++i) { trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); } auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); }; for (auto& a : tape) { // Arrays in the tape without primitives are either: // - inputs, which are already in the map // - constants, which can be used 
directly // - a load primitive which has no inputs and will become a constant // after the first eval if (!a.has_primitive() || is_load(a.primitive())) { trace_to_real.insert({a.id(), a}); } else { // Find real inputs std::vector real_inputs; for (auto& in : a.inputs()) { real_inputs.push_back(trace_to_real.at(in.id())); } if (a.siblings().empty()) { auto shape = shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape(); auto real_a = array( std::move(shape), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); trace_to_real.insert({a.id(), std::move(real_a)}); } else { // Ensure the order is correct for multi-output primitives std::vector types; auto trace_out = a.outputs(); for (auto& o : trace_out) { types.push_back(o.dtype()); } std::vector shapes; if (shapeless) { shapes = a.primitive().output_shapes(real_inputs); } else { for (auto& o : trace_out) { shapes.push_back(o.shape()); } } auto real_out = array::make_arrays( std::move(shapes), types, a.primitive_ptr(), real_inputs); for (int i = 0; i < trace_out.size(); ++i) { trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])}); } } } } std::vector outputs; for (auto& o : trace_outputs) { outputs.push_back(trace_to_real.at(o.id())); } return outputs; } bool skip_compile() { return compile_mode() == CompileMode::disabled || !(compile_available_for_device(default_device())); } ArrayFnWithExtra compile( ArrayFnWithExtra fun, std::uintptr_t fun_id, bool shapeless /* = false */, std::vector constants /* = {} */) { if (skip_compile()) { return fun; } if (!fun) { throw std::invalid_argument( "[compile] Cannot compile a function without a target."); } return [fun = std::move(fun), fun_id, shapeless, constants = std::move(constants)](const std::vector& inputs) { // If the inputs are tracers, trace the original graph if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) { return in.is_tracer(); })) { return fun(inputs); } // Find a cache entry with the correct inputs auto& entry = 
compiler_cache().find(fun_id, inputs, shapeless, constants); // No matching cache entry existed, so compile if (entry.empty) { // Mark the entry as not empty since we are about to fill it entry.empty = false; // Set the constants entry.constants = std::move(constants); // Trace to build the graph std::tie(entry.inputs, entry.outputs, entry.extra) = compile_trace(fun, inputs, shapeless); // DFS the graph and get a tape, and a map of array id to (parent, // position in parent inputs) std::unordered_map>> parents_map; std::tie(entry.tape, parents_map) = compile_dfs(entry.inputs, entry.outputs, inputs); // Simplify the tape auto mode = compile_mode().load(); if (mode != CompileMode::no_simplify) { compile_simplify( entry.tape, parents_map, entry.outputs, /* passes */ 3); } // Kernel fusion to generate Compiled primitives. The tape and // new outputs must be updated accordingly if (mode != CompileMode::no_fuse) { compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs); } } // At this point we must have a tape, now replace the placeholders // with real arrays that can be evaluated return ArraysAndExtra{ compile_replace( entry.tape, entry.inputs, entry.outputs, inputs, shapeless), entry.extra}; }; } std::function(const std::vector&)> compile( std::function(const std::vector&)> fun, std::uintptr_t fun_id, bool shapeless /* = false */, std::vector constants /* = {} */) { if (skip_compile()) { return fun; } if (!fun) { throw std::invalid_argument( "[compile] Cannot compile a function without a target."); } ArrayFnWithExtra fun_with_extra = [fun = std::move(fun)](const std::vector& inputs) { return ArraysAndExtra{fun(inputs), nullptr}; }; auto compiled_fun = compile( std::move(fun_with_extra), fun_id, shapeless, std::move(constants)); return [compiled_fun = std::move(compiled_fun)](const std::vector& inputs) { return compiled_fun(inputs).first; }; } void compile_erase(std::uintptr_t fun_id) { detail::compiler_cache().erase(fun_id); } void compile_clear_cache() { 
detail::compiler_cache().clear(); } } // namespace detail std::function(const std::vector&)> compile( std::function(const std::vector&)> fun, bool shapeless /* false */) { if (detail::skip_compile()) { return fun; } auto fun_id = detail::get_function_address(fun); if (fun_id) { // If the function has an addressable target then no need to manage it's // lifetime return detail::compile(std::move(fun), fun_id, shapeless); } else { auto pfun = std::shared_ptr< std::function(const std::vector&)>>( new std::function(const std::vector&)>{fun}, [](auto* p) { detail::compile_erase(reinterpret_cast(p)); delete p; }); fun_id = reinterpret_cast(pfun.get()); return detail::compile( [pfun = std::move(pfun)](const auto& inputs) { return (*pfun)(inputs); }, fun_id, shapeless); } } std::function(const std::vector&)> compile( std::vector (*fun)(const std::vector&), bool shapeless /* = false */) { if (detail::skip_compile()) { return fun; } return detail::compile(fun, reinterpret_cast(fun), shapeless); } void disable_compile() { detail::compile_mode() = CompileMode::disabled; } void enable_compile() { detail::compile_mode() = CompileMode::enabled; } void set_compile_mode(CompileMode mode) { detail::compile_mode() = mode; } } // namespace mlx::core ================================================ FILE: mlx/compile.h ================================================ // Copyright © 2023-2024 Apple Inc. #pragma once #include "mlx/api.h" #include "mlx/array.h" namespace mlx::core { enum class CompileMode { disabled, no_simplify, no_fuse, enabled }; /** Compile takes a function and returns a compiled function. */ MLX_API std::function(const std::vector&)> compile( std::function(const std::vector&)> fun, bool shapeless = false); MLX_API std::function(const std::vector&)> compile( std::vector (*fun)(const std::vector&), bool shapeless = false); // Convert capture-less lambdas to function pointers. 
// Overload for callables such as capture-less lambdas: unary `+` converts `f`
// to a plain function pointer so the function-pointer overload is selected.
// NOTE(review): the enable_if condition is presumably "convertible to the
// function-pointer signature" — confirm against the original header.
template <
    typename F,
    typename = std::enable_if_t<
        std::is_convertible_v())>>>
std::function(const std::vector&)> compile(
    F&& f,
    bool shapeless = false) {
  return compile(+f, shapeless);
}

/** Globally disable compilation.
 * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
 * be used to disable compilation.
 */
MLX_API void disable_compile();

/** Globally enable compilation.
 * This will override the environment variable ``MLX_DISABLE_COMPILE``.
 */
MLX_API void enable_compile();

/** Set the compiler mode to the given value.
 * ``disabled`` returns functions uncompiled, ``no_simplify`` skips graph
 * simplification, ``no_fuse`` skips kernel fusion, and ``enabled`` runs both.
 */
MLX_API void set_compile_mode(CompileMode mode);
} // namespace mlx::core
================================================
FILE: mlx/compile_impl.h
================================================
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include

#include "mlx/api.h"
#include "mlx/array.h"

namespace mlx::core::detail {

// A traced function's outputs paired with an opaque extra payload.
using ArraysAndExtra =
    std::pair, std::shared_ptr>;
// A function of arrays that also returns an extra payload.
using ArrayFnWithExtra =
    std::function&)>;

// This is not part of the general C++ API as calling with a bad id is a bad
// idea.
MLX_API std::function(const std::vector&)> compile(
    std::function(const std::vector&)> fun,
    std::uintptr_t fun_id,
    bool shapeless = false,
    std::vector constants = {});

MLX_API ArrayFnWithExtra compile(
    ArrayFnWithExtra fun,
    std::uintptr_t fun_id,
    bool shapeless,
    std::vector constants);

// Erase cached compile functions
MLX_API void compile_erase(std::uintptr_t fun_id);

// Clear the compiler cache causing a recompilation of all compiled functions
// when called again.
MLX_API void compile_clear_cache(); bool compile_available_for_device(const Device& device); std::tuple, std::vector, std::shared_ptr> compile_trace( const ArrayFnWithExtra& fun, const std::vector& inputs, bool shapeless); using ParentsMap = std::unordered_map>>; // Traverses the graph to build a tape and a map of array ids to their parents std::pair, ParentsMap> compile_dfs( const std::vector& inputs, std::vector& outputs, const std::vector& original_inputs); // Simplify the tape. void compile_simplify( std::vector& tape, ParentsMap& parents_map, std::vector& outputs, int passes); std::vector compile_replace( const std::vector& tape, const std::vector& trace_inputs, const std::vector& trace_outputs, const std::vector& inputs, bool shapeless); void compile_validate_shapeless(const std::vector& tape); } // namespace mlx::core::detail ================================================ FILE: mlx/device.cpp ================================================ // Copyright © 2023-2026 Apple Inc. #include #include "mlx/backend/cpu/device_info.h" #include "mlx/backend/gpu/device_info.h" #include "mlx/device.h" namespace mlx::core { Device& mutable_default_device() { static Device default_device{gpu::is_available() ? 
Device::gpu : Device::cpu}; return default_device; } const Device& default_device() { return mutable_default_device(); } void set_default_device(const Device& d) { if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[set_default_device] Cannot set gpu device without gpu backend."); } mutable_default_device() = d; } bool operator==(const Device& lhs, const Device& rhs) { return lhs.type == rhs.type && lhs.index == rhs.index; } bool operator!=(const Device& lhs, const Device& rhs) { return !(lhs == rhs); } bool is_available(const Device& d) { switch (d.type) { case Device::cpu: return cpu::is_available() && (d.index < cpu::device_count()); case Device::gpu: return gpu::is_available() && (d.index < gpu::device_count()); } // appease compiler return false; } int device_count(Device::DeviceType type) { switch (type) { case Device::cpu: return cpu::device_count(); case Device::gpu: return gpu::device_count(); } // appease compiler return 0; } const std::unordered_map>& device_info(const Device& d) { switch (d.type) { case Device::cpu: return cpu::device_info(d.index); case Device::gpu: return gpu::device_info(d.index); } // appease compiler static std::unordered_map> empty; return empty; } } // namespace mlx::core ================================================ FILE: mlx/device.h ================================================ // Copyright © 2023-2025 Apple Inc. 
#pragma once #include "mlx/api.h" #include #include #include namespace mlx::core { struct MLX_API Device { enum class DeviceType { cpu, gpu, }; static constexpr DeviceType cpu = DeviceType::cpu; static constexpr DeviceType gpu = DeviceType::gpu; Device(DeviceType type, int index = 0) : type(type), index(index) {} DeviceType type; int index; }; MLX_API const Device& default_device(); MLX_API void set_default_device(const Device& d); MLX_API bool operator==(const Device& lhs, const Device& rhs); MLX_API bool operator!=(const Device& lhs, const Device& rhs); MLX_API bool is_available(const Device& d); /** Get the number of available devices for the given device type. */ MLX_API int device_count(Device::DeviceType type); /** * Get information about a device. * * Returns a map of device properties. Keys vary by backend: * - device_name (string): Device name * - architecture (string): Architecture identifier * - total_memory/memory_size (size_t): Total device memory * - free_memory (size_t): Available memory (CUDA only) * - uuid (string): Device UUID (CUDA only) * - pci_bus_id (string): PCI bus ID (CUDA only) * - compute_capability_major/minor (size_t): Compute capability (CUDA only) */ MLX_API const std::unordered_map>& device_info(const Device& d = default_device()); } // namespace mlx::core ================================================ FILE: mlx/distributed/CMakeLists.txt ================================================ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp) if(MLX_BUILD_CPU AND NOT WIN32) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl) ================================================ FILE: mlx/distributed/distributed.cpp 
================================================ // Copyright © 2024 Apple Inc. #include #include "mlx/backend/cuda/cuda.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/jaccl/jaccl.h" #include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/nccl/nccl.h" #include "mlx/distributed/ring/ring.h" namespace mlx::core::distributed { namespace detail { Stream communication_stream(Group group, StreamOrDevice s /* = {} */) { return group.raw_group()->communication_stream(s); } void all_sum(Group group, const array& input, array& output, Stream stream) { group.raw_group()->all_sum(input, output, stream); } void all_max(Group group, const array& input, array& output, Stream stream) { group.raw_group()->all_max(input, output, stream); } void all_min(Group group, const array& input, array& output, Stream stream) { group.raw_group()->all_min(input, output, stream); } void all_gather(Group group, const array& input, array& output, Stream stream) { group.raw_group()->all_gather(input, output, stream); } void send(Group group, const array& input, int dst, Stream stream) { group.raw_group()->send(input, dst, stream); } void recv(Group group, array& out, int src, Stream stream) { group.raw_group()->recv(out, src, stream); } void sum_scatter( Group group, const array& input, array& output, Stream stream) { group.raw_group()->sum_scatter(input, output, stream); } class EmptyGroup : public GroupImpl { public: Stream communication_stream(StreamOrDevice s) override { return to_stream(s); } int rank() override { return 0; } int size() override { return 1; } std::shared_ptr split(int color, int key = -1) override { throw std::runtime_error("Cannot split the distributed group further."); } void all_sum(const array&, array&, Stream) override { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } void all_gather(const array&, array&, Stream) override { throw std::runtime_error( 
"Communication not implemented in an empty distributed group."); } void send(const array&, int, Stream) override { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } void recv(array&, int, Stream) override { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } void all_max(const array&, array&, Stream) override { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } void all_min(const array&, array&, Stream) override { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } void sum_scatter(const array&, array&, Stream) override { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } }; } // namespace detail bool is_available() { return mpi::is_available() || ring::is_available() || nccl::is_available() || jaccl::is_available(); } bool is_available(const std::string& bk) { if (bk == "any") { return is_available(); } if (bk == "mpi") { return mpi::is_available(); } if (bk == "ring") { return ring::is_available(); } if (bk == "nccl") { return nccl::is_available(); } if (bk == "jaccl") { return jaccl::is_available(); } return false; } int Group::rank() const { return group_->rank(); } int Group::size() const { return group_->size(); } Group Group::split(int color, int key /* = -1 */) const { return Group(group_->split(color, key)); } Group init(bool strict /* = false */, const std::string& bk /* = "any" */) { static std::unordered_map> backends; // Already initialized so return the group. 
if (auto g = backends.find(bk); g != backends.end()) { return Group(g->second); } // Create the requested communication group std::shared_ptr group{nullptr}; std::string bk_ = bk; if (bk == "mpi") { group = mpi::init(strict); } else if (bk == "ring") { group = ring::init(strict); } else if (bk == "nccl") { group = nccl::init(strict); } else if (bk == "jaccl") { group = jaccl::init(strict); } else if (bk == "any") { if (mlx::core::cu::is_available()) { group = nccl::init(false); bk_ = "nccl"; } if (group == nullptr) { group = ring::init(false); bk_ = "ring"; } if (group == nullptr) { group = mpi::init(false); bk_ = "mpi"; } if (group == nullptr) { group = jaccl::init(false); bk_ = "jaccl"; } if (group == nullptr && strict) { throw std::runtime_error("[distributed] Couldn't initialize any backend"); } } else { std::ostringstream msg; msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', " << "'jaccl' and 'ring' but '" << bk << "' was provided."; throw std::invalid_argument(msg.str()); } if (group == nullptr) { group = std::make_shared(); } else { backends.insert({"any", group}); } backends.insert({std::move(bk_), group}); return Group(group); } } // namespace mlx::core::distributed ================================================ FILE: mlx/distributed/distributed.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include #include "mlx/api.h" #include "mlx/array.h" #include "mlx/utils.h" namespace mlx::core::distributed { // Forward declaration of the base group implementation. namespace detail { class GroupImpl; }; /* Check if a communication backend is available */ MLX_API bool is_available(); MLX_API bool is_available(const std::string& bk); /** * A distributed::Group represents a group of independent mlx processes that * can communicate. We must also be able to create sub-groups from a group in * order to define more granular communication. 
*/
struct MLX_API Group {
  // Wraps a backend implementation; the Group shares ownership of it.
  Group(std::shared_ptr group) : group_(std::move(group)) {}

  // Rank of this process within the group.
  int rank() const;

  // Number of processes in the group.
  int size() const;

  /**
   * Split the group according to the provided color. Namely processes that use
   * the same color will go to the same group.
   *
   * The key defines the rank of the processes in the new group. The smaller
   * the key the smaller the rank. If the provided key is negative, then the
   * rank in the current group is used.
   */
  Group split(int color, int key = -1) const;

  // Underlying implementation, used by the detail:: forwarding functions.
  const std::shared_ptr& raw_group() const {
    return group_;
  }

 private:
  std::shared_ptr group_{nullptr};
};

/**
 * Initialize the distributed backend and return the group containing all
 * discoverable processes.
 *
 * If strict is true then throw an error if we couldn't initialize the
 * distributed subsystem. Otherwise simply return a singleton group which will
 * render communication operations as no-op.
 *
 * bk selects the backend: "any", "mpi", "ring", "nccl" or "jaccl"; any other
 * value throws std::invalid_argument. Groups are cached by backend name, so
 * repeated successful calls with the same bk return the same group.
 *
 * NOTE(review): the fallback singleton group defined in distributed.cpp
 * (EmptyGroup) throws from its communication methods; the no-op behavior
 * presumably comes from the ops layer short-circuiting size-1 groups —
 * confirm.
 */
MLX_API Group init(bool strict = false, const std::string& bk = "any");

} // namespace mlx::core::distributed
================================================
FILE: mlx/distributed/distributed_impl.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed::detail {

/**
 * Abstract base class of a distributed group implementation.
*/ class GroupImpl { public: virtual ~GroupImpl() {} // Choose the stream this communication group can operate on virtual Stream communication_stream(StreamOrDevice s = {}) = 0; // Group operations virtual int rank() = 0; virtual int size() = 0; virtual std::shared_ptr split(int color, int key = -1) = 0; // Actual communication operations virtual void all_sum(const array& input, array& output, Stream stream) = 0; virtual void all_gather(const array& input, array& output, Stream stream) = 0; virtual void send(const array& input, int dst, Stream stream) = 0; virtual void recv(array& out, int src, Stream stream) = 0; virtual void all_max(const array& input, array& output, Stream stream) = 0; virtual void all_min(const array& input, array& output, Stream stream) = 0; virtual void sum_scatter(const array& input, array& output, Stream stream) = 0; }; /* Define the MLX stream that the communication should happen in. */ Stream communication_stream(Group group, StreamOrDevice s = {}); /* Perform an all reduce sum operation */ void all_sum(Group group, const array& input, array& output, Stream stream); /* Perform an all gather operation */ void all_gather(Group group, const array& input, array& output, Stream stream); /** Send an array to the dst rank */ void send(Group group, const array& input, int dst, Stream stream); /** Recv an array from the src rank */ void recv(Group group, array& out, int src, Stream stream); /** Max reduction */ void all_max(Group group, const array& input, array& output, Stream stream); /** Min reduction */ void all_min(Group group, const array& input, array& output, Stream stream); /** Reduce scatter with average operation */ void sum_scatter(Group group, const array& input, array& output, Stream stream); } // namespace mlx::core::distributed::detail ================================================ FILE: mlx/distributed/jaccl/CMakeLists.txt ================================================ if(MLX_BUILD_CPU AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin" 
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2) target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mesh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp) endif() ================================================ FILE: mlx/distributed/jaccl/jaccl.cpp ================================================ // Copyright © 2025 Apple Inc. #include #include #include #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/jaccl/mesh.h" #include "mlx/distributed/jaccl/ring.h" #include "mlx/distributed/jaccl/utils.h" using GroupImpl = mlx::core::distributed::detail::GroupImpl; using json = nlohmann::json; namespace { struct DeviceFile { DeviceFile(const char* dev_file) { std::ifstream f(dev_file); json devices = json::parse(f); if (!devices.is_array()) { throw std::runtime_error( "[jaccl] The device file should start with an array"); } devices_.resize(devices.size()); for (int rank = 0; rank < devices.size(); rank++) { auto conn = devices[rank]; if (!conn.is_array()) { throw std::runtime_error( "[jaccl] The device file should have an array of arrays"); } if (conn.size() != devices_.size()) { std::ostringstream msg; msg << "[jaccl] The device file should contain the connectivity of each rank to " << "all other ranks but rank " << rank << " contains only " << conn.size() << " entries."; throw std::runtime_error(msg.str()); } devices_[rank].resize(conn.size()); for (int dst = 0; dst < conn.size(); dst++) { auto names = conn[dst]; if (names.is_string()) { devices_[rank][dst].push_back(names); } else if (names.is_array()) { for (auto name_it = names.begin(); name_it != names.end(); name_it++) { devices_[rank][dst].push_back(*name_it); } } else if (!names.is_null()) { throw std::runtime_error( "[jaccl] Device names should be null, a string or array of strings."); } } } } int size() { return devices_.size(); } bool is_valid_mesh() 
{ for (int src = 0; src < size(); src++) { for (int dst = 0; dst < size(); dst++) { if (devices_[src][dst].size() != static_cast(src != dst)) { return false; } } } return true; } bool is_valid_ring() { int num_connections = devices_[0][1].size(); if (num_connections == 0) { return false; } for (int src = 0; src < size(); src++) { int left = (src + size() - 1) % size(); int right = (src + 1) % size(); for (int dst = 0; dst < size(); dst++) { if (dst != left && dst != right) { if (devices_[src][dst].size() != 0) { return false; } } else { if (devices_[src][dst].size() != num_connections) { return false; } } } } return true; } std::vector extract_mesh_connectivity(int rank) { std::vector devices(size()); for (int dst = 0; dst < size(); dst++) { if (dst != rank) { devices[dst] = devices_[rank][dst][0]; } } return devices; } std::pair, std::vector> extract_ring_connectivity(int rank) { int left = (rank + size() - 1) % size(); int right = (rank + 1) % size(); return std::make_pair(devices_[rank][left], devices_[rank][right]); } std::vector>> devices_; }; } // namespace namespace mlx::core::distributed::jaccl { bool is_available() { return ibv().is_available(); } std::shared_ptr init(bool strict /* = false */) { const char* dev_file = std::getenv("MLX_IBV_DEVICES"); const char* coordinator = std::getenv("MLX_JACCL_COORDINATOR"); const char* rank_str = std::getenv("MLX_RANK"); const char* ring = std::getenv("MLX_JACCL_RING"); if (!is_available() || !dev_file || !coordinator || !rank_str) { if (strict) { std::ostringstream msg; msg << "[jaccl] You need to provide via environment variables a rank (MLX_RANK), " << "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_JACCL_COORDINATOR) " << "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "") << "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "") << "\" and MLX_JACCL_COORDINATOR=\"" << ((coordinator) ? 
coordinator : ""); throw std::runtime_error(msg.str()); } return nullptr; } auto rank = std::atoi(rank_str); bool prefer_ring = ring != nullptr; DeviceFile devices(dev_file); if (rank >= devices.size() || rank < 0) { std::ostringstream msg; msg << "[jaccl] Invalid rank " << rank << ". It should be between 0 and " << devices.size(); throw std::runtime_error(msg.str()); } if (prefer_ring && devices.is_valid_ring()) { auto [left, right] = devices.extract_ring_connectivity(rank); return std::make_shared( rank, devices.size(), left, right, coordinator); } else if (devices.is_valid_mesh()) { auto device_names = devices.extract_mesh_connectivity(rank); return std::make_shared(rank, device_names, coordinator); } else if (devices.is_valid_ring()) { auto [left, right] = devices.extract_ring_connectivity(rank); return std::make_shared( rank, devices.size(), left, right, coordinator); } else { throw std::runtime_error( "[jaccl] The device file should define a valid mesh or a valid ring."); } } } // namespace mlx::core::distributed::jaccl ================================================ FILE: mlx/distributed/jaccl/jaccl.h ================================================ // Copyright © 2025 Apple Inc. #include "mlx/distributed/distributed.h" namespace mlx::core::distributed::jaccl { using GroupImpl = mlx::core::distributed::detail::GroupImpl; bool is_available(); std::shared_ptr init(bool strict = false); } // namespace mlx::core::distributed::jaccl ================================================ FILE: mlx/distributed/jaccl/mesh.cpp ================================================ // Copyright © 2026 Apple Inc. 
#include "mlx/distributed/jaccl/mesh.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/reduction_ops.h"
#include "mlx/dtype_utils.h"

namespace mlx::core::distributed::jaccl {

// Builds a fully connected JACCL mesh group.
//
// device_names has one entry per rank, so the group size is
// device_names.size(). jaccl.cpp's extract_mesh_connectivity leaves the entry
// for our own rank empty.
//
// NOTE(review): the MESH_MAX_PEERS check runs only after side_channel_ and
// connections_ have been constructed in the initializer list, so an oversized
// mesh is rejected late — consider validating before connecting.
MeshGroup::MeshGroup(
    int rank,
    const std::vector& device_names,
    const char* coordinator_addr)
    : rank_(rank),
      size_(device_names.size()),
      side_channel_(rank_, size_, coordinator_addr),
      connections_(create_connections(device_names)) {
  if (size_ > MESH_MAX_PEERS) {
    std::ostringstream msg;
    msg << "[jaccl] The JACCL mesh supports up to " << MESH_MAX_PEERS
        << " peers but " << size_ << " were provided.";
    throw std::runtime_error(msg.str());
  }

  // Initialize all the connections and allocate buffers
  initialize();

  // Make sure every node has reached here before continuing
  side_channel_.all_gather(0);

  // Create the mesh implementation object
  mesh_ = MeshImpl(rank_, size_, connections_, buffers_);
  // The ring implementation (used for large reductions, see all_reduce below)
  // reuses the mesh connections to the left and right neighbours.
  ring_ = RingImpl(
      rank_,
      size_,
      &connections_[(rank_ + size_ - 1) % size_],
      &connections_[(rank_ + 1) % size_],
      1,
      ring_send_buffers_,
      ring_recv_buffers_);
}

// Sets up every peer connection: protection domain, completion queue and
// queue pair for each connection with a device context. It then allocates
// and registers the buffers, exchanges connection info over the side channel,
// and moves each queue pair through RTR to RTS.
void MeshGroup::initialize() {
  // Create the queue pairs
  for (auto& conn : connections_) {
    if (conn.ctx == nullptr) {
      continue;
    }
    conn.allocate_protection_domain();
    conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR);
    conn.create_queue_pair();
  }

  allocate_buffers();

  // First init all connections
  for (int peer = 0; peer < size_; peer++) {
    if (peer == rank_) {
      continue;
    }
    connections_[peer].queue_pair_init();
  }

  // Gather the information to be exchanged, this also serves as a barrier so
  // that all peers have initialized their connections before attempting to
  // transition to RTS.
  std::vector info;
  for (auto& conn : connections_) {
    info.emplace_back(conn.info());
  }
  auto all_infos = side_channel_.all_gather(info);

  // Transition queue pairs to RTS
  for (int peer = 0; peer < size_; peer++) {
    if (peer == rank_) {
      continue;
    }
    // all_infos[peer][rank_] is what `peer` published about its connection
    // to us.
    auto peer_info = all_infos[peer][rank_];
    connections_[peer].queue_pair_rtr(peer_info);
    connections_[peer].queue_pair_rts();
  }
}

// Allocates and registers all staging buffers.
//
// Mesh layout: buffers_[k * NUM_BUFFERS * size_ + i * size_ + j] is size
// class k (FRAME_SIZE << k bytes), pipeline slot i, rank j. The j == rank_
// slot is our send buffer and the others receive from rank j.
// Ring layout: ring_*_buffers_[k * NUM_BUFFERS * 2 + i * 2 + d]. For d == 0
// we send right and receive from the left; for d == 1 the other way around.
//
// NOTE(review): MeshImpl keeps spans into buffers_, so calling this again
// after the constructor would leave mesh_ dangling.
void MeshGroup::allocate_buffers() {
  // Deregister any buffers and free the memory
  buffers_.clear();
  ring_send_buffers_.clear();
  ring_recv_buffers_.clear();

  // Allocate the memory
  for (int k = 0; k < BUFFER_SIZES; k++) {
    for (int i = 0; i < NUM_BUFFERS; i++) {
      // Mesh buffers
      for (int j = 0; j < size_; j++) {
        buffers_.emplace_back(FRAME_SIZE * (1 << k));
      }
      // Ring buffers (1 for each direction)
      for (int j = 0; j < 2; j++) {
        ring_send_buffers_.emplace_back(FRAME_SIZE * (1 << k));
        ring_recv_buffers_.emplace_back(FRAME_SIZE * (1 << k));
      }
    }
  }

  // Registration happens only after every buffer exists, so vector growth in
  // the loop above cannot move an already-registered buffer.
  for (int k = 0; k < BUFFER_SIZES; k++) {
    for (int i = 0; i < NUM_BUFFERS; i++) {
      // Mesh buffers
      for (int j = 0; j < size_; j++) {
        // This is our send buffer so register it with all pds so we can send
        // it to all connected devices.
        if (j == rank_) {
          for (auto& conn : connections_) {
            if (conn.ctx != nullptr) {
              buffers_[k * NUM_BUFFERS * size_ + i * size_ + j]
                  .register_to_protection_domain(conn.protection_domain);
            }
          }
        }

        // This is the recv buffer from rank j so register it to rank j's
        // protection domain.
        else {
          buffers_[k * NUM_BUFFERS * size_ + i * size_ + j]
              .register_to_protection_domain(connections_[j].protection_domain);
        }
      }

      // Ring buffers (see ring group for the logic below)
      // We register send buffers to both the right and the left.
      // NOTE(review): with size_ == 1, left == right == rank_, whose
      // connection presumably has no context (it is skipped elsewhere) —
      // confirm register_to_protection_domain tolerates that.
      int left = (rank_ + size_ - 1) % size_;
      int right = (rank_ + 1) % size_;
      ring_send_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 0]
          .register_to_protection_domain(connections_[right].protection_domain);
      ring_recv_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 0]
          .register_to_protection_domain(connections_[left].protection_domain);
      ring_send_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 1]
          .register_to_protection_domain(connections_[left].protection_domain);
      ring_recv_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 1]
          .register_to_protection_domain(connections_[right].protection_domain);
    }
  }
}

// Element-wise sum across all ranks, dispatched on the output dtype.
void MeshGroup::all_sum(const array& input, array& output, Stream stream) {
  dispatch_all_types(output.dtype(), [&](auto type_tag) {
    using T = MLX_GET_TYPE(type_tag);
    all_reduce(input, output, stream, detail::SumOp{});
  });
}

// Element-wise max across all ranks, dispatched on the output dtype.
void MeshGroup::all_max(const array& input, array& output, Stream stream) {
  dispatch_all_types(output.dtype(), [&](auto type_tag) {
    using T = MLX_GET_TYPE(type_tag);
    all_reduce(input, output, stream, detail::MaxOp{});
  });
}

// Element-wise min across all ranks, dispatched on the output dtype.
void MeshGroup::all_min(const array& input, array& output, Stream stream) {
  dispatch_all_types(output.dtype(), [&](auto type_tag) {
    using T = MLX_GET_TYPE(type_tag);
    all_reduce(input, output, stream, detail::MinOp{});
  });
}

// Gathers every rank's input into output. MeshImpl::all_gather writes rank
// r's n_bytes at offset r * n_bytes, so output presumably holds
// size_ * n_bytes — confirm at the call site.
void MeshGroup::all_gather(const array& input, array& output, Stream stream) {
  auto in_ptr = input.data();
  auto out_ptr = output.data();
  size_t n_bytes = input.nbytes();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(input);
  encoder.set_output_array(output);
  // The task captures `this`; assumes the group outlives the dispatched work.
  encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() {
    mesh_.all_gather(in_ptr, out_ptr, n_bytes);
  });
}

// Sends the whole input buffer to rank dst.
void MeshGroup::send(const array& input, int dst, Stream stream) {
  auto data = input.data();
  int64_t n_bytes = input.nbytes();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(input);
  encoder.dispatch(
      [data, n_bytes, dst, this]() { mesh_.send(data, n_bytes, dst); });
}

// Receives out.nbytes() bytes from rank src into out.
void MeshGroup::recv(array& out, int src, Stream stream) {
  auto data = out.data();
  int64_t n_bytes = out.nbytes();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_output_array(out);
  encoder.dispatch(
      [data, n_bytes, src, this]() { mesh_.recv(data, n_bytes, src); });
}

// Shared implementation of all_sum/all_max/all_min.
//
// With more than 2 ranks, large messages go through the ring implementation.
// "Large" means more than 65536 elements for the type tested by
// std::is_same_v, or at least 8 MiB otherwise. Everything else uses the
// fully connected mesh.
template
void MeshGroup::all_reduce(
    const array& input,
    array& output,
    Stream stream,
    ReduceOp reduce_op) {
  auto in_ptr = input.data();
  auto out_ptr = output.data();
  int64_t size = input.size();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(input);
  encoder.set_output_array(output);
  encoder.dispatch([in_ptr, out_ptr, size, this, reduce_op]() {
    if (size_ > 2 &&
        ((std::is_same_v && size > 65536) ||
         size >= 8 * 1024 * 1024 / sizeof(T))) {
      ring_.all_reduce<2>(in_ptr, out_ptr, size, 1, reduce_op);
    } else {
      mesh_.all_reduce(in_ptr, out_ptr, size, reduce_op);
    }
  });
}

} // namespace mlx::core::distributed::jaccl
================================================
FILE: mlx/distributed/jaccl/mesh.h
================================================
// Copyright © 2026 Apple Inc.

#pragma once

#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/jaccl/mesh_impl.h"
#include "mlx/distributed/jaccl/ring_impl.h"
#include "mlx/distributed/jaccl/utils.h"

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

namespace mlx::core::distributed::jaccl {

/**
 * The JACCL communication group for a fully connected mesh. We expect one
 * connection per peer and it should be the lowest latency communication group
 * for small to medium size messages.
 *
 * Like all JACCL groups it uses a side channel to exchange the necessary
 * information and then configure the connections to be ready for RDMA
 * operations.
*/
class MeshGroup : public GroupImpl {
 public:
  MeshGroup(
      int rank,
      const std::vector& device_names,
      const char* coordinator_addr);

  // Resolves s to a stream on the CPU device; the mesh operations are
  // encoded with cpu::get_command_encoder (see mesh.cpp).
  Stream communication_stream(StreamOrDevice s) override {
    return to_stream(s, Device::cpu);
  }

  int rank() override {
    return rank_;
  }

  int size() override {
    return size_;
  }

  void all_sum(const array& input, array& output, Stream stream) override;
  void all_max(const array& input, array& output, Stream stream) override;
  void all_min(const array& input, array& output, Stream stream) override;
  void all_gather(const array& input, array& output, Stream stream) override;
  void send(const array& input, int dst, Stream stream) override;
  void recv(array& out, int src, Stream stream) override;

  // Not implemented by the mesh group; always throws std::runtime_error.
  void sum_scatter(const array& input, array& output, Stream stream) override {
    throw std::runtime_error("[jaccl] sum_scatter not supported.");
  }

  // Not implemented by the mesh group; always throws std::runtime_error.
  std::shared_ptr split(int color, int key = -1) override {
    throw std::runtime_error("[jaccl] Group split not supported.");
  }

 private:
  // Common implementation of all_sum/all_max/all_min. Picks the ring or the
  // mesh algorithm depending on group and message size.
  template
  void all_reduce(
      const array& input,
      array& output,
      Stream stream,
      ReduceOp reduce_op);

  /**
   * Performs the connection initialization. Namely, after this call all
   * Connection objects should have a queue pair in RTS state and all buffers
   * should have been allocated.
   */
  void initialize();

  /**
   * Allocate all the buffers that we will use in the communication group.
   */
  void allocate_buffers();

  // NOTE: declaration order matters. The constructor's initializer list
  // builds side_channel_ from rank_ and size_, so those must come first.
  int rank_;
  int size_;
  SideChannel side_channel_;
  std::vector connections_;
  std::vector buffers_;
  std::vector ring_send_buffers_;
  std::vector ring_recv_buffers_;
  // mesh_ keeps views of connections_ and buffers_, and ring_ keeps pointers
  // into connections_. Both are assigned at the end of the constructor.
  MeshImpl mesh_;
  RingImpl ring_;
};

} // namespace mlx::core::distributed::jaccl
================================================
FILE: mlx/distributed/jaccl/mesh_impl.h
================================================
// Copyright © 2026 Apple Inc.
#pragma once #include #include "mlx/distributed/jaccl/utils.h" constexpr int MESH_MAX_PEERS = 8; namespace mlx::core::distributed::jaccl { class MeshImpl { public: MeshImpl( int rank, int size, std::vector& conns, std::vector& buffers) : rank_(rank), size_(size), connections_(conns), buffers_(buffers) {} MeshImpl() : rank_(0), size_(1) {} template void all_reduce(const T* in_ptr, T* out_ptr, int64_t size, ReduceOp reduce_op) { // If not inplace all reduce then copy the input to the output first if (in_ptr != out_ptr) { std::memcpy(out_ptr, in_ptr, size * sizeof(T)); } // Fully connected all reduce T* data = out_ptr; auto [sz, buffer_size] = buffer_size_from_message(size * sizeof(T)); int64_t N = buffer_size / sizeof(T); constexpr int PIPELINE = 2; constexpr int WC_NUM = PIPELINE * MESH_MAX_PEERS * 2; int64_t total = static_cast(size); int num_peers = size_ - 1; // Counters to maintain the state of transfers int in_flight = 0; int64_t read_offset = 0; int completed_send_count[PIPELINE] = {0}; int completed_recv_begin[MESH_MAX_PEERS] = {0}; int completed_recv_end[MESH_MAX_PEERS] = {0}; // Prefill the pipeline int buff = 0; while (read_offset < total && buff < PIPELINE) { post_recv_all(sz, buff); std::copy( data + read_offset, data + std::min(read_offset + N, total), send_buffer(sz, buff).begin()); post_send_all(sz, buff); buff++; in_flight += 2 * num_peers; read_offset += N; } // Main loop // // Keep going until we have no longer data in flight. while (in_flight > 0) { // Poll the hardware for completions. // // If a send was completed mark how many completions we have received // for that buffer. If we have sent the buffer to all peers we can // reuse the buffer so copy the next chunk of data and send it to all. // // If a receive is completed then advance the pointer of completed // receives. 
ibv_wc wc[WC_NUM]; int n = poll(connections_, WC_NUM, wc); for (int i = 0; i < n; i++) { int work_type = wc[i].wr_id >> 16; int buff = (wc[i].wr_id >> 8) & 0xff; int rank = wc[i].wr_id & 0xff; in_flight--; if (work_type == SEND_WR && read_offset < total) { completed_send_count[buff]++; if (completed_send_count[buff] == num_peers) { std::copy( data + read_offset, data + std::min(read_offset + N, total), send_buffer(sz, buff).begin()); post_send_all(sz, buff); completed_send_count[buff] = 0; in_flight += num_peers; read_offset += N; } } else if (work_type == RECV_WR) { completed_recv_end[rank]++; } } // Process the completed recv // // For each rank we have a range of completed recv defined by a begin // and end inclusive and exlusive in standard C++ fashion. // // When there is an unprocessed receive we first check if we have // finished sending the write location. If so then we reduce in-place // and then check if there is more to be received and post a recv. for (int r = 0; r < size_; r++) { int s = completed_recv_begin[r]; int e = completed_recv_end[r]; int w = s * N; while (w < read_offset && e - s > 0) { int buff = s % PIPELINE; reduce_op( recv_buffer(sz, buff, r).begin(), data + w, std::min(N, total - w)); w += N; s++; if (w + (PIPELINE - 1) * N < total) { recv_from(sz, r, buff); in_flight++; } } completed_recv_begin[r] = s; } } } void all_gather(const char* in_ptr, char* out_ptr, int64_t n_bytes) { // Copy our data to the appropriate place std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes); // Fully connected all gather char* data = out_ptr; char* our_data = out_ptr + rank_ * n_bytes; auto [sz, N] = buffer_size_from_message(n_bytes); constexpr int PIPELINE = 2; constexpr int WC_NUM = PIPELINE * MESH_MAX_PEERS * 2; int64_t total = static_cast(n_bytes); int num_peers = size_ - 1; // Counters to maintain the state of transfers int in_flight = 0; int read_offset = 0; int completed_send_count[PIPELINE] = {0}; int write_offset[MESH_MAX_PEERS] = {0}; // Prefill 
the pipeline int buff = 0; while (read_offset < total && buff < PIPELINE) { post_recv_all(sz, buff); std::copy( our_data + read_offset, our_data + std::min(read_offset + N, total), send_buffer(sz, buff).begin()); post_send_all(sz, buff); buff++; in_flight += 2 * num_peers; read_offset += N; } // Main loop // // Keep going until we have no longer data in flight. while (in_flight > 0) { ibv_wc wc[WC_NUM]; int n = poll(connections_, WC_NUM, wc); for (int i = 0; i < n; i++) { int work_type = wc[i].wr_id >> 16; int buff = (wc[i].wr_id >> 8) & 0xff; int rank = wc[i].wr_id & 0xff; in_flight--; // Send completed. If all sends completed then send the next chunk. if (work_type == SEND_WR && read_offset < total) { completed_send_count[buff]++; if (completed_send_count[buff] == num_peers) { std::copy( our_data + read_offset, our_data + std::min(read_offset + N, total), send_buffer(sz, buff).begin()); post_send_all(sz, buff); completed_send_count[buff] = 0; in_flight += num_peers; read_offset += N; } } // Recv completed. If we have more chunks then post another recv. else if (work_type == RECV_WR) { std::copy( recv_buffer(sz, buff, rank).begin(), recv_buffer(sz, buff, rank).begin() + std::min(N, total - write_offset[rank]), data + rank * n_bytes + write_offset[rank]); write_offset[rank] += N; if (write_offset[rank] + N * (PIPELINE - 1) < total) { recv_from(sz, rank, buff); in_flight++; } } } } } void send(const char* in_ptr, int64_t n_bytes, int dst) { constexpr int PIPELINE = 2; constexpr int WC_NUM = PIPELINE; auto [sz, N] = buffer_size_from_message(n_bytes); int in_flight = 0; int64_t read_offset = 0; // Prefill the pipeline int buff = 0; while (read_offset < n_bytes && buff < PIPELINE) { std::copy( in_ptr + read_offset, in_ptr + std::min(read_offset + N, n_bytes), send_buffer(sz, buff).begin()); send_to(sz, dst, buff); buff++; read_offset += N; in_flight++; } // Main loop while (in_flight > 0) { // Poll the hardware for completions. 
// // If a send was completed and we have more data to send then go ahead // and send them. ibv_wc wc[WC_NUM]; int n = connections_[dst].poll(WC_NUM, wc); for (int i = 0; i < n; i++) { int buff = (wc[i].wr_id >> 8) & 0xff; int rank = wc[i].wr_id & 0xff; in_flight--; if (read_offset < n_bytes) { std::copy( in_ptr + read_offset, in_ptr + std::min(read_offset + N, n_bytes), send_buffer(sz, buff).begin()); send_to(sz, dst, buff); read_offset += N; in_flight++; } } } } void recv(char* out_ptr, int64_t n_bytes, int src) { constexpr int PIPELINE = 2; constexpr int WC_NUM = PIPELINE; auto [sz, N] = buffer_size_from_message(n_bytes); int in_flight = 0; int64_t write_offset = 0; // Prefill the pipeline int buff = 0; while (N * buff < n_bytes && buff < PIPELINE) { recv_from(sz, src, buff); in_flight++; buff++; } // Main loop while (in_flight > 0) { // Poll the hardware for completions. // // If a recv was completed copy it to the output and if we have more // data to fetch post another recv. ibv_wc wc[WC_NUM]; int n = connections_[src].poll(WC_NUM, wc); for (int i = 0; i < n; i++) { int buff = (wc[i].wr_id >> 8) & 0xff; int rank = wc[i].wr_id & 0xff; in_flight--; std::copy( recv_buffer(sz, buff, src).begin(), recv_buffer(sz, buff, src).begin() + std::min(n_bytes - write_offset, static_cast(N)), out_ptr + write_offset); write_offset += N; if (write_offset + (PIPELINE - 1) * N < n_bytes) { recv_from(sz, src, buff); in_flight++; } } } } private: void send_to(int sz, int rank, int buff) { connections_[rank].post_send( send_buffer(sz, buff), SEND_WR << 16 | buff << 8 | rank); } void recv_from(int sz, int rank, int buff) { connections_[rank].post_recv( recv_buffer(sz, buff, rank), RECV_WR << 16 | buff << 8 | rank); } SharedBuffer& send_buffer(int sz, int buff) { return buffers_[sz * NUM_BUFFERS * size_ + buff * size_ + rank_]; } SharedBuffer& recv_buffer(int sz, int buff, int rank) { return buffers_[sz * NUM_BUFFERS * size_ + buff * size_ + rank]; } void post_send_all(int sz, int 
buff) { auto& b = send_buffer(sz, buff); int wr_id = SEND_WR << 16 | buff << 8; for (int i = 0; i < size_; i++) { if (i == rank_) { continue; } connections_[i].post_send(b, wr_id | i); } } void post_recv_all(int sz, int buff) { int b = sz * NUM_BUFFERS * size_ + buff * size_; int wr_id = RECV_WR << 16 | buff << 8; for (int i = 0; i < size_; i++) { if (i == rank_) { continue; } connections_[i].post_recv(buffers_[b + i], wr_id | i); } } int rank_; int size_; std::span connections_; std::span buffers_; }; } // namespace mlx::core::distributed::jaccl ================================================ FILE: mlx/distributed/jaccl/no_jaccl.cpp ================================================ // Copyright © 2025 Apple Inc. #include "mlx/distributed/jaccl/jaccl.h" namespace mlx::core::distributed::jaccl { using GroupImpl = mlx::core::distributed::detail::GroupImpl; bool is_available() { return false; } std::shared_ptr init(bool strict /* = false */) { if (strict) { throw std::runtime_error("Cannot initialize jaccl distributed backend."); } return nullptr; } } // namespace mlx::core::distributed::jaccl ================================================ FILE: mlx/distributed/jaccl/ring.cpp ================================================ // Copyright © 2026 Apple Inc. 
#include "mlx/distributed/jaccl/ring.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/reduction_ops.h"
#include "mlx/dtype_utils.h"

namespace mlx::core::distributed::jaccl {

RingGroup::RingGroup(
    int rank,
    int size,
    const std::vector<std::string>& left_devices,
    const std::vector<std::string>& right_devices,
    const char* coordinator_addr)
    : rank_(rank),
      size_(size),
      n_conns_(left_devices.size()),
      side_channel_(rank_, size_, coordinator_addr),
      left_(create_connections(left_devices)),
      right_(create_connections(right_devices)) {
  // The buffer layout and RingImpl index both directions with the same wire
  // index in [0, n_conns_), and n_conns_ is taken from the left devices.
  // Unequal counts would therefore index past the end of right_ (or of the
  // peer's infos in initialize()).
  if (left_.size() != right_.size()) {
    std::ostringstream msg;
    msg << "[jaccl] The ring requires the same number of connections in both "
        << "directions but " << left_.size() << " left and " << right_.size()
        << " right connections were provided.";
    throw std::runtime_error(msg.str());
  }
  if (left_.size() > static_cast<size_t>(RING_MAX_CONNS)) {
    std::ostringstream msg;
    msg << "[jaccl] Up to " << RING_MAX_CONNS
        << " connections per direction supported but " << left_.size()
        << " were provided.";
    throw std::runtime_error(msg.str());
  }

  // Initialize all the connections and allocate buffers
  initialize();

  // Make sure every node has reached here before continuing
  side_channel_.all_gather(0);

  // Create the ring implementation object
  ring_ = RingImpl(rank_, size_, left_, right_, send_buffers_, recv_buffers_);
}

void RingGroup::initialize() {
  // Create the queue pairs
  for (auto& conn : left_) {
    conn.allocate_protection_domain();
    conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR);
    conn.create_queue_pair();
  }
  for (auto& conn : right_) {
    conn.allocate_protection_domain();
    conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR);
    conn.create_queue_pair();
  }

  // Allocate the buffers
  allocate_buffers();

  // Initialize the connections
  for (auto& conn : left_) {
    conn.queue_pair_init();
  }
  for (auto& conn : right_) {
    conn.queue_pair_init();
  }

  // Gather the information to be exchanged, this also serves as a barrier so
  // that all peers have initialized their connections before attempting to
  // transition to RTS.
  std::vector<Destination> left_info;
  left_info.reserve(left_.size());
  for (auto& conn : left_) {
    left_info.emplace_back(conn.info());
  }
  std::vector<Destination> right_info;
  right_info.reserve(right_.size());
  for (auto& conn : right_) {
    right_info.emplace_back(conn.info());
  }
  auto all_left_infos = side_channel_.all_gather(left_info);
  auto all_right_infos = side_channel_.all_gather(right_info);

  // Transition queue pairs to RTS. Our left wire i is paired with our left
  // neighbour's right wire i and vice versa.
  int left_peer = (rank_ + size_ - 1) % size_;
  for (size_t i = 0; i < left_.size(); i++) {
    const auto& peer_info = all_right_infos[left_peer][i];
    left_[i].queue_pair_rtr(peer_info);
    left_[i].queue_pair_rts();
  }
  int right_peer = (rank_ + 1) % size_;
  for (size_t i = 0; i < right_.size(); i++) {
    const auto& peer_info = all_left_infos[right_peer][i];
    right_[i].queue_pair_rtr(peer_info);
    right_[i].queue_pair_rts();
  }
}

void RingGroup::allocate_buffers() {
  // Deregister any buffers and free the memory
  send_buffers_.clear();
  recv_buffers_.clear();

  // Buffers are laid out as [size class k][pipeline slot i][j]. j in
  // [0, n_conns_) is direction 0 and j in [n_conns_, 2 * n_conns_) is
  // direction 1; the wire is j % n_conns_. Size class k holds
  // FRAME_SIZE * 2^k bytes.
  size_t n_total = static_cast<size_t>(BUFFER_SIZES) * NUM_BUFFERS * n_conns_ * 2;
  send_buffers_.reserve(n_total);
  recv_buffers_.reserve(n_total);
  for (int k = 0; k < BUFFER_SIZES; k++) {
    for (int i = 0; i < NUM_BUFFERS; i++) {
      for (int j = 0; j < n_conns_ * 2; j++) {
        send_buffers_.emplace_back(FRAME_SIZE * (1 << k));
        recv_buffers_.emplace_back(FRAME_SIZE * (1 << k));
      }
    }
  }

  // Register the buffers with the corresponding connections. Direction 0
  // sends on the right connection and receives on the left one; direction 1
  // is the opposite.
  for (int k = 0; k < BUFFER_SIZES; k++) {
    for (int i = 0; i < NUM_BUFFERS; i++) {
      for (int j = 0; j < n_conns_ * 2; j++) {
        int idx = k * NUM_BUFFERS * n_conns_ * 2 + i * n_conns_ * 2 + j;
        int wire = j % n_conns_;
        bool direction_one = j >= n_conns_;
        auto& send_conn = direction_one ? left_[wire] : right_[wire];
        auto& recv_conn = direction_one ? right_[wire] : left_[wire];
        send_buffers_[idx].register_to_protection_domain(
            send_conn.protection_domain);
        recv_buffers_[idx].register_to_protection_domain(
            recv_conn.protection_domain);
      }
    }
  }
}

void RingGroup::all_sum(const array& input, array& output, Stream stream) {
  dispatch_all_types(output.dtype(), [&](auto type_tag) {
    using T = MLX_GET_TYPE(type_tag);
    all_reduce<T>(input, output, stream, detail::SumOp<T>{});
  });
}

void RingGroup::all_max(const array& input, array& output, Stream stream) {
  dispatch_all_types(output.dtype(), [&](auto type_tag) {
    using T = MLX_GET_TYPE(type_tag);
    all_reduce<T>(input, output, stream, detail::MaxOp<T>{});
  });
}

void RingGroup::all_min(const array& input, array& output, Stream stream) {
  dispatch_all_types(output.dtype(), [&](auto type_tag) {
    using T = MLX_GET_TYPE(type_tag);
    all_reduce<T>(input, output, stream, detail::MinOp<T>{});
  });
}

void RingGroup::all_gather(const array& input, array& output, Stream stream) {
  auto in_ptr = input.data<char>();
  auto out_ptr = output.data<char>();
  int64_t n_bytes = input.nbytes();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(input);
  encoder.set_output_array(output);
  encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() {
    ring_.all_gather(in_ptr, out_ptr, n_bytes, n_conns_);
  });
}

void RingGroup::send(const array& input, int dst, Stream stream) {
  int right = (rank_ + 1) % size_;
  int left = (rank_ + size_ - 1) % size_;
  if (dst != right && dst != left) {
    std::ostringstream msg;
    msg << "[jaccl] In ring mode send is only supported to direct neighbors "
        << "but tried to send to " << dst << " from " << rank_ << std::endl;
    throw std::runtime_error(msg.str());
  }
  auto data = input.data<char>();
  int64_t n_bytes = input.nbytes();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(input);
  encoder.dispatch([data, n_bytes, dst, this]() {
    ring_.send(data, n_bytes, dst, n_conns_);
  });
}

void RingGroup::recv(array& out, int src, Stream stream) {
  int right = (rank_ + 1) % size_;
  int left = (rank_ + size_ - 1) % size_;
  if (src != right && src != left) {
    std::ostringstream msg;
    msg << "[jaccl] In ring mode recv is only supported to direct neighbors "
        << "but tried to recv from " << src << " to " << rank_ << std::endl;
    throw std::runtime_error(msg.str());
  }
  auto data = out.data<char>();
  int64_t n_bytes = out.nbytes();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_output_array(out);
  encoder.dispatch([data, n_bytes, src, this]() {
    ring_.recv(data, n_bytes, src, n_conns_);
  });
}

template <typename T, typename ReduceOp>
void RingGroup::all_reduce(
    const array& input,
    array& output,
    Stream stream,
    ReduceOp reduce_op) {
  auto in_ptr = input.data<T>();
  auto out_ptr = output.data<T>();
  int64_t size = input.size();
  int64_t n_bytes = input.nbytes();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(input);
  encoder.set_output_array(output);
  encoder.dispatch([in_ptr, out_ptr, size, n_bytes, this, reduce_op]() {
    // Fewer elements than 2 per (rank, wire): one direction on one wire.
    if (size < size_ * 2 * n_conns_) {
      ring_.all_reduce<1, T, ReduceOp>(in_ptr, out_ptr, size, 1, reduce_op);
      return;
    }

    // Small messages: both directions, a single wire each.
    if (n_bytes <= 65536) {
      ring_.all_reduce<2, T, ReduceOp>(in_ptr, out_ptr, size, 1, reduce_op);
      return;
    }

    // Large messages: both directions striped over every wire.
    ring_.all_reduce<2, T, ReduceOp>(
        in_ptr, out_ptr, size, n_conns_, reduce_op);
  });
}

} // namespace mlx::core::distributed::jaccl

================================================
FILE: mlx/distributed/jaccl/ring.h
================================================
// Copyright © 2026 Apple Inc.

#pragma once

#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/jaccl/ring_impl.h"
#include "mlx/distributed/jaccl/utils.h"

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

namespace mlx::core::distributed::jaccl {

/**
 * The JACCL communication group for a ring where each node is connected to its
 * two neighboring nodes. It should be the highest bandwidth communication
 * group for large messages when many connections per peer are used.
 *
 * Like all JACCL groups it uses a side channel to exchange the necessary
 * information and then configure the connections to be ready for RDMA
 * operations.
 *
 * `left_devices` and `right_devices` must have the same length, at most
 * RING_MAX_CONNS; the constructor throws std::runtime_error otherwise.
 */
class RingGroup : public GroupImpl {
 public:
  RingGroup(
      int rank,
      int size,
      const std::vector<std::string>& left_devices,
      const std::vector<std::string>& right_devices,
      const char* coordinator_addr);

  Stream communication_stream(StreamOrDevice s) override {
    return to_stream(s, Device::cpu);
  }

  int rank() override {
    return rank_;
  }

  int size() override {
    return size_;
  }

  void all_sum(const array& input, array& output, Stream stream) override;
  void all_max(const array& input, array& output, Stream stream) override;
  void all_min(const array& input, array& output, Stream stream) override;
  void all_gather(const array& input, array& output, Stream stream) override;
  void send(const array& input, int dst, Stream stream) override;
  void recv(array& out, int src, Stream stream) override;

  void sum_scatter(const array& input, array& output, Stream stream) override {
    throw std::runtime_error("[jaccl] sum_scatter not supported.");
  }

  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
    throw std::runtime_error("[jaccl] Group split not supported.");
  }

 private:
  template <typename T, typename ReduceOp>
  void all_reduce(
      const array& input,
      array& output,
      Stream stream,
      ReduceOp reduce_op);

  /**
   * Performs the connection initialization. Namely, after this call all
   * Connection objects should have a queue pair in RTS state and all buffers
   * should have been allocated.
   */
  void initialize();

  /**
   * Allocate all the buffers that we will use in the communication group.
   */
  void allocate_buffers();

  int rank_;
  int size_;
  int n_conns_;
  SideChannel side_channel_;
  std::vector<Connection> left_;
  std::vector<Connection> right_;
  std::vector<SharedBuffer> send_buffers_;
  std::vector<SharedBuffer> recv_buffers_;
  RingImpl ring_;
};

} // namespace mlx::core::distributed::jaccl

================================================
FILE: mlx/distributed/jaccl/ring_impl.h
================================================
// Copyright © 2026 Apple Inc.
#pragma once #include #include "mlx/distributed/jaccl/utils.h" constexpr int RING_MAX_CONNS = 4; namespace mlx::core::distributed::jaccl { class RingImpl { public: RingImpl( int rank, int size, std::vector& left, std::vector& right, std::vector& send_buffers, std::vector& recv_buffers) : rank_(rank), size_(size), n_conns_(left.size()), left_(left), right_(right), send_buffers_(send_buffers), recv_buffers_(recv_buffers) {} RingImpl( int rank, int size, Connection* left_begin, Connection* right_begin, size_t n_conns, std::vector& send_buffers, std::vector& recv_buffers) : rank_(rank), size_(size), n_conns_(n_conns), left_(left_begin, n_conns), right_(right_begin, n_conns), send_buffers_(send_buffers), recv_buffers_(recv_buffers) {} RingImpl() : rank_(0), size_(1), n_conns_(0) {} template void all_reduce( const T* in_ptr, T* out_ptr, int64_t size, int n_wires, ReduceOp reduce_op) { // If not inplace all reduce then copy the input to the output first if (in_ptr != out_ptr) { std::memcpy(out_ptr, in_ptr, size * sizeof(T)); } constexpr int PIPELINE = 2; constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS * 2 * MAX_DIR; int64_t chunk_size = (size + size_ - 1) / size_; int64_t size_per_wire = (chunk_size + (MAX_DIR * n_wires) - 1) / (MAX_DIR * n_wires); auto [sz, N] = buffer_size_from_message(size_per_wire * sizeof(T)); N /= sizeof(T); int64_t n_steps = (size_per_wire + N - 1) / N; // Counters to maintain the state of transfers int in_flight = 0; int64_t chunk_multiple_size = size_ * chunk_size; int64_t send_offset[MAX_DIR]; int64_t recv_offset[MAX_DIR]; int64_t send_limits[MAX_DIR]; int64_t recv_limits[MAX_DIR]; int send_count[MAX_DIR * RING_MAX_CONNS] = {0}; int recv_count[MAX_DIR * RING_MAX_CONNS] = {0}; send_offset[0] = rank_ * chunk_size; recv_offset[0] = ((rank_ + size_ - 1) % size_) * chunk_size; if constexpr (MAX_DIR == 2) { send_offset[1] = rank_ * chunk_size; recv_offset[1] = ((rank_ + 1) % size_) * chunk_size; send_limits[0] = std::min( n_wires * size_per_wire, 
std::max(0, size - send_offset[0])); send_limits[1] = std::min(chunk_size, std::max(0, size - send_offset[1])); recv_limits[0] = std::min( n_wires * size_per_wire, std::max(0, size - recv_offset[0])); recv_limits[1] = std::min(chunk_size, std::max(0, size - recv_offset[1])); } else { send_limits[0] = std::min(chunk_size, std::max(0, size - send_offset[0])); recv_limits[0] = std::min(chunk_size, std::max(0, size - recv_offset[0])); } // First reduce scatter // // Possible perf improvement by not syncing at every step but running ahead // as needed. for (int k = 0; k < size_ - 1; k++) { // Prefill the pipeline int buff = 0; while (buff < n_steps && buff < PIPELINE) { post_recv_all(sz, buff, n_wires); for (int lr = 0; lr < MAX_DIR; lr++) { for (int lw = 0; lw < n_wires; lw++) { int64_t offset = lw * N + send_count[lr * RING_MAX_CONNS + lw] * n_wires * N + lr * n_wires * size_per_wire; std::copy( out_ptr + send_offset[lr] + offset, out_ptr + send_offset[lr] + std::max(offset, std::min(offset + N, send_limits[lr])), send_buffer(sz, buff, lr, lw).begin()); send_count[lr * RING_MAX_CONNS + lw]++; } } post_send_all(sz, buff, n_wires); buff++; in_flight += 2 * MAX_DIR * n_wires; } // Main loop // // Keep going until we have no longer data in flight. 
while (in_flight > 0) { ibv_wc wc[WC_NUM]; int n = poll(left_, right_, WC_NUM, wc); for (int i = 0; i < n; i++) { int work_type = wc[i].wr_id >> 16; int buff = (wc[i].wr_id >> 8) & 0xff; int wire = wc[i].wr_id & 0xff; int lr = wire / RING_MAX_CONNS; int lw = wire % RING_MAX_CONNS; in_flight--; if (work_type == SEND_WR && send_count[wire] < n_steps) { int64_t offset = lw * N + send_count[wire] * n_wires * N + lr * n_wires * size_per_wire; std::copy( out_ptr + send_offset[lr] + offset, out_ptr + send_offset[lr] + std::max(offset, std::min(offset + N, send_limits[lr])), send_buffer(sz, buff, lr, lw).begin()); send_to(sz, buff, lr, lw); in_flight++; send_count[wire]++; } else if (work_type == RECV_WR) { int64_t offset = lw * N + recv_count[wire] * n_wires * N + lr * n_wires * size_per_wire; reduce_op( recv_buffer(sz, buff, lr, lw).begin(), out_ptr + recv_offset[lr] + offset, std::max(0, std::min(N, recv_limits[lr] - offset))); recv_count[wire]++; if (recv_count[wire] + (PIPELINE - 1) < n_steps) { recv_from(sz, buff, lr, lw); in_flight++; } } } } send_offset[0] = (send_offset[0] + chunk_multiple_size - chunk_size) % chunk_multiple_size; recv_offset[0] = (recv_offset[0] + chunk_multiple_size - chunk_size) % chunk_multiple_size; if constexpr (MAX_DIR == 2) { send_offset[1] = (send_offset[1] + chunk_size) % chunk_multiple_size; recv_offset[1] = (recv_offset[1] + chunk_size) % chunk_multiple_size; send_limits[0] = std::min( n_wires * size_per_wire, std::max(0, size - send_offset[0])); send_limits[1] = std::min(chunk_size, std::max(0, size - send_offset[1])); recv_limits[0] = std::min( n_wires * size_per_wire, std::max(0, size - recv_offset[0])); recv_limits[1] = std::min(chunk_size, std::max(0, size - recv_offset[1])); } else { send_limits[0] = std::min(chunk_size, std::max(0, size - send_offset[0])); recv_limits[0] = std::min(chunk_size, std::max(0, size - recv_offset[0])); } for (int i = 0; i < MAX_DIR * RING_MAX_CONNS; i++) { send_count[i] = recv_count[i] = 0; } } // 
Secondly all gather // // The offsets are correct from the scatter reduce for (int k = 0; k < size_ - 1; k++) { // Prefill the pipeline int buff = 0; while (buff < n_steps && buff < PIPELINE) { post_recv_all(sz, buff, n_wires); for (int lr = 0; lr < MAX_DIR; lr++) { for (int lw = 0; lw < n_wires; lw++) { int64_t offset = lw * N + send_count[lr * RING_MAX_CONNS + lw] * n_wires * N + lr * n_wires * size_per_wire; std::copy( out_ptr + send_offset[lr] + offset, out_ptr + send_offset[lr] + std::max(offset, std::min(offset + N, send_limits[lr])), send_buffer(sz, buff, lr, lw).begin()); send_count[lr * RING_MAX_CONNS + lw]++; } } post_send_all(sz, buff, n_wires); buff++; in_flight += 2 * MAX_DIR * n_wires; } // Main loop // // Keep going until we have no longer data in flight. while (in_flight > 0) { ibv_wc wc[WC_NUM]; int n = poll(left_, right_, WC_NUM, wc); for (int i = 0; i < n; i++) { int work_type = wc[i].wr_id >> 16; int buff = (wc[i].wr_id >> 8) & 0xff; int wire = wc[i].wr_id & 0xff; int lr = wire / RING_MAX_CONNS; int lw = wire % RING_MAX_CONNS; in_flight--; if (work_type == SEND_WR && send_count[wire] < n_steps) { int64_t offset = lw * N + send_count[wire] * n_wires * N + lr * n_wires * size_per_wire; std::copy( out_ptr + send_offset[lr] + offset, out_ptr + send_offset[lr] + std::max(offset, std::min(offset + N, send_limits[lr])), send_buffer(sz, buff, lr, lw).begin()); send_to(sz, buff, lr, lw); in_flight++; send_count[wire]++; } else if (work_type == RECV_WR) { int64_t offset = lw * N + recv_count[wire] * n_wires * N + lr * n_wires * size_per_wire; std::copy( recv_buffer(sz, buff, lr, lw).begin(), recv_buffer(sz, buff, lr, lw).begin() + std::max(0, std::min(N, recv_limits[lr] - offset)), out_ptr + recv_offset[lr] + offset); recv_count[wire]++; if (recv_count[wire] + (PIPELINE - 1) < n_steps) { recv_from(sz, buff, lr, lw); in_flight++; } } } } send_offset[0] = (send_offset[0] + chunk_multiple_size - chunk_size) % chunk_multiple_size; recv_offset[0] = 
(recv_offset[0] + chunk_multiple_size - chunk_size) % chunk_multiple_size; if constexpr (MAX_DIR == 2) { send_offset[1] = (send_offset[1] + chunk_size) % chunk_multiple_size; recv_offset[1] = (recv_offset[1] + chunk_size) % chunk_multiple_size; send_limits[0] = std::min( n_wires * size_per_wire, std::max(0, size - send_offset[0])); send_limits[1] = std::min(chunk_size, std::max(0, size - send_offset[1])); recv_limits[0] = std::min( n_wires * size_per_wire, std::max(0, size - recv_offset[0])); recv_limits[1] = std::min(chunk_size, std::max(0, size - recv_offset[1])); } else { send_limits[0] = std::min(chunk_size, std::max(0, size - send_offset[0])); recv_limits[0] = std::min(chunk_size, std::max(0, size - recv_offset[0])); } for (int i = 0; i < MAX_DIR * RING_MAX_CONNS; i++) { send_count[i] = recv_count[i] = 0; } } } void all_gather(const char* in_ptr, char* out_ptr, int64_t n_bytes, int n_wires) { // Copy our data to the appropriate place std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes); constexpr int PIPELINE = 2; constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS * 2 * 2; size_t n_bytes_per_wire = (n_bytes + (2 * n_wires) - 1) / (2 * n_wires); size_t out_bytes = n_bytes * size_; auto [sz, N] = buffer_size_from_message(n_bytes_per_wire); int n_steps = (n_bytes_per_wire + N - 1) / N; // Counters to maintain the state of transfers int in_flight = 0; int64_t send_offset[2]; int64_t recv_offset[2]; int64_t limits[2]; int send_count[2 * RING_MAX_CONNS] = {0}; int recv_count[2 * RING_MAX_CONNS] = {0}; send_offset[0] = send_offset[1] = rank_ * n_bytes; recv_offset[0] = ((rank_ + size_ - 1) % size_) * n_bytes; recv_offset[1] = ((rank_ + 1) % size_) * n_bytes; limits[0] = n_wires * n_bytes_per_wire; limits[1] = n_bytes; // Possible perf improvement by not syncing at every step but running ahead // as needed. 
for (int k = 0; k < size_ - 1; k++) { // Prefill the pipeline int buff = 0; while (buff < n_steps && buff < PIPELINE) { post_recv_all(sz, buff); for (int lr = 0; lr < 2; lr++) { for (int lw = 0; lw < n_wires; lw++) { int64_t offset = lw * N + send_count[lr * RING_MAX_CONNS + lw] * n_wires * N + lr * n_wires * n_bytes_per_wire; std::copy( out_ptr + send_offset[lr] + offset, out_ptr + send_offset[lr] + std::max(offset, std::min(offset + N, limits[lr])), send_buffer(sz, buff, lr, lw).begin()); send_count[lr * RING_MAX_CONNS + lw]++; } } post_send_all(sz, buff); buff++; in_flight += 2 * 2 * n_wires; } // Main loop // // Keep going until we have no longer data in flight. while (in_flight > 0) { ibv_wc wc[WC_NUM]; int n = poll(left_, right_, WC_NUM, wc); for (int i = 0; i < n; i++) { int work_type = wc[i].wr_id >> 16; int buff = (wc[i].wr_id >> 8) & 0xff; int wire = wc[i].wr_id & 0xff; int lr = wire / RING_MAX_CONNS; int lw = wire % RING_MAX_CONNS; in_flight--; if (work_type == SEND_WR && send_count[wire] < n_steps) { int64_t offset = lw * N + send_count[wire] * n_wires * N + lr * n_wires * n_bytes_per_wire; std::copy( out_ptr + send_offset[lr] + offset, out_ptr + send_offset[lr] + std::max(offset, std::min(offset + N, limits[lr])), send_buffer(sz, buff, lr, lw).begin()); send_to(sz, buff, lr, lw); in_flight++; send_count[wire]++; } else if (work_type == RECV_WR) { int64_t offset = lw * N + recv_count[wire] * n_wires * N + lr * n_wires * n_bytes_per_wire; std::copy( recv_buffer(sz, buff, lr, lw).begin(), recv_buffer(sz, buff, lr, lw).begin() + std::max(0, std::min(N, limits[lr] - offset)), out_ptr + recv_offset[lr] + offset); recv_count[wire]++; if (recv_count[wire] + (PIPELINE - 1) < n_steps) { recv_from(sz, buff, lr, lw); in_flight++; } } } } send_offset[0] = (send_offset[0] + out_bytes - n_bytes) % out_bytes; recv_offset[0] = (recv_offset[0] + out_bytes - n_bytes) % out_bytes; send_offset[1] = (send_offset[1] + n_bytes) % out_bytes; recv_offset[1] = (recv_offset[1] + 
n_bytes) % out_bytes; for (int i = 0; i < 2 * RING_MAX_CONNS; i++) { send_count[i] = recv_count[i] = 0; } } } void send(const char* in_ptr, int64_t n_bytes, int dst, int n_wires) { int left = (rank_ + size_ - 1) % size_; // In the case that size_ == 2 then left == right so we bias send towards // left and recv towards right so that the selections will be correct for // the 2 node case. auto& conns = (dst == left) ? left_ : right_; int dir = dst == left; constexpr int PIPELINE = 2; constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS; int64_t bytes_per_wire = (n_bytes + n_wires - 1) / n_wires; auto [sz, N] = buffer_size_from_message(bytes_per_wire); int in_flight = 0; int64_t read_offset[RING_MAX_CONNS]; int64_t limits[RING_MAX_CONNS]; for (int lw = 0; lw < n_wires; lw++) { read_offset[lw] = std::min(lw * bytes_per_wire, n_bytes); limits[lw] = std::min((lw + 1) * bytes_per_wire, n_bytes); } // Prefill the pipeline for (int lw = 0; lw < n_wires; lw++) { int buff = 0; while (read_offset[lw] < limits[lw] && buff < PIPELINE) { std::copy( in_ptr + read_offset[lw], in_ptr + std::min(read_offset[lw] + N, limits[lw]), send_buffer(sz, buff, dir, lw).begin()); send_to(sz, buff, dir, lw); buff++; read_offset[lw] += N; in_flight++; } } // Main loop while (in_flight > 0) { // Poll the hardware for completions. // // If a send was completed and we have more data to send then go ahead // and send them. 
ibv_wc wc[WC_NUM]; int n = poll(conns, WC_NUM, wc); for (int i = 0; i < n; i++) { int buff = (wc[i].wr_id >> 8) & 0xff; int wire = wc[i].wr_id & 0xff; int lw = wire % RING_MAX_CONNS; in_flight--; if (read_offset[lw] < limits[lw]) { std::copy( in_ptr + read_offset[lw], in_ptr + std::min(read_offset[lw] + N, limits[lw]), send_buffer(sz, buff, dir, lw).begin()); send_to(sz, buff, dir, lw); read_offset[lw] += N; in_flight++; } } } } void recv(char* out_ptr, int64_t n_bytes, int src, int n_wires) { int right = (rank_ + 1) % size_; // In the case that size_ == 2 then left == right so we bias send towards // left and recv towards right so that the selections will be correct for // the 2 node case. auto& conns = (src == right) ? right_ : left_; int dir = src == right; constexpr int PIPELINE = 2; constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS; int64_t bytes_per_wire = (n_bytes + n_wires - 1) / n_wires; auto [sz, N] = buffer_size_from_message(bytes_per_wire); int in_flight = 0; int64_t write_offset[RING_MAX_CONNS]; int64_t limits[RING_MAX_CONNS]; for (int lw = 0; lw < n_wires; lw++) { write_offset[lw] = std::min(lw * bytes_per_wire, n_bytes); limits[lw] = std::min((lw + 1) * bytes_per_wire, n_bytes); } // Prefill the pipeline for (int lw = 0; lw < n_wires; lw++) { int buff = 0; while (N * buff < limits[lw] && buff < PIPELINE) { recv_from(sz, buff, dir, lw); buff++; in_flight++; } } // Main loop while (in_flight > 0) { // Poll the hardware for completions. // // If a recv was completed copy it to the output and if we have more // data to fetch post another recv. 
ibv_wc wc[WC_NUM]; int n = poll(conns, WC_NUM, wc); for (int i = 0; i < n; i++) { int buff = (wc[i].wr_id >> 8) & 0xff; int wire = wc[i].wr_id & 0xff; int lw = wire % RING_MAX_CONNS; in_flight--; std::copy( recv_buffer(sz, buff, dir, lw).begin(), recv_buffer(sz, buff, dir, lw).begin() + std::max( 0, std::min(limits[lw] - write_offset[lw], N)), out_ptr + write_offset[lw]); write_offset[lw] += N; if (write_offset[lw] + (PIPELINE - 1) * N < limits[lw]) { recv_from(sz, buff, dir, lw); in_flight++; } } } } private: void send_to(int sz, int buff, int left_right, int wire) { if (left_right) { left_[wire].post_send( send_buffer_left(sz, buff, wire), SEND_WR << 16 | buff << 8 | (RING_MAX_CONNS + wire)); } else { right_[wire].post_send( send_buffer_right(sz, buff, wire), SEND_WR << 16 | buff << 8 | wire); } } void recv_from(int sz, int buff, int left_right, int wire) { if (left_right) { right_[wire].post_recv( recv_buffer_right(sz, buff, wire), RECV_WR << 16 | buff << 8 | (RING_MAX_CONNS + wire)); } else { left_[wire].post_recv( recv_buffer_left(sz, buff, wire), RECV_WR << 16 | buff << 8 | wire); } } SharedBuffer& send_buffer_right(int sz, int buff, int wire) { return send_buffers_ [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + wire]; } SharedBuffer& send_buffer_left(int sz, int buff, int wire) { return send_buffers_ [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + n_conns_ + wire]; } SharedBuffer& send_buffer(int sz, int buff, int left_right, int wire) { return send_buffers_ [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + left_right * n_conns_ + wire]; } SharedBuffer& recv_buffer_left(int sz, int buff, int wire) { return recv_buffers_ [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + wire]; } SharedBuffer& recv_buffer_right(int sz, int buff, int wire) { return recv_buffers_ [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + n_conns_ + wire]; } SharedBuffer& recv_buffer(int sz, int buff, int left_right, int wire) { return recv_buffers_ 
[sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + left_right * n_conns_ + wire]; } template void post_recv_all(int sz, int buff, int n_wires) { for (int lr = 0; lr < MAX_DIR; lr++) { for (int lw = 0; lw < n_wires; lw++) { recv_from(sz, buff, lr, lw); } } } void post_recv_all(int sz, int buff) { post_recv_all<2>(sz, buff, n_conns_); } template void post_send_all(int sz, int buff, int n_wires) { for (int lr = 0; lr < MAX_DIR; lr++) { for (int lw = 0; lw < n_wires; lw++) { send_to(sz, buff, lr, lw); } } } void post_send_all(int sz, int buff) { post_send_all<2>(sz, buff, n_conns_); } int rank_; int size_; int n_conns_; std::span left_; std::span right_; std::span send_buffers_; std::span recv_buffers_; }; } // namespace mlx::core::distributed::jaccl ================================================ FILE: mlx/distributed/jaccl/utils.cpp ================================================ // Copyright © 2025 Apple Inc. #include #include #include #include #include "mlx/distributed/jaccl/utils.h" #define LOAD_SYMBOL(symbol, variable) \ { \ variable = (decltype(variable))dlsym(librdma_handle_, #symbol); \ char* error = dlerror(); \ if (error != nullptr) { \ std::cerr << IBV_TAG << " " << error << std::endl; \ librdma_handle_ = nullptr; \ return; \ } \ } namespace { void* page_aligned_alloc(size_t num_bytes) { static size_t page_size = sysconf(_SC_PAGESIZE); void* buf; if (posix_memalign(&buf, page_size, num_bytes)) { return nullptr; } return buf; } } // namespace namespace mlx::core::distributed::jaccl { IBVWrapper::IBVWrapper() { librdma_handle_ = dlopen("librdma.dylib", RTLD_NOW | RTLD_GLOBAL); if (librdma_handle_ == nullptr) { return; } LOAD_SYMBOL(ibv_get_device_list, get_device_list); LOAD_SYMBOL(ibv_get_device_name, get_device_name); LOAD_SYMBOL(ibv_open_device, open_device); LOAD_SYMBOL(ibv_free_device_list, free_device_list); LOAD_SYMBOL(ibv_close_device, close_device); LOAD_SYMBOL(ibv_alloc_pd, alloc_pd); LOAD_SYMBOL(ibv_create_qp, create_qp); 
LOAD_SYMBOL(ibv_create_cq, create_cq); LOAD_SYMBOL(ibv_destroy_cq, destroy_cq); LOAD_SYMBOL(ibv_destroy_qp, destroy_qp); LOAD_SYMBOL(ibv_dealloc_pd, dealloc_pd); LOAD_SYMBOL(ibv_query_port, query_port); LOAD_SYMBOL(ibv_query_gid, query_gid); LOAD_SYMBOL(ibv_modify_qp, modify_qp); LOAD_SYMBOL(ibv_reg_mr, reg_mr); LOAD_SYMBOL(ibv_dereg_mr, dereg_mr); // Not really symbols but leaving them here in case they become symbols in // the future. // // LOAD_SYMBOL(ibv_post_send, post_send); // LOAD_SYMBOL(ibv_post_recv, post_recv); // LOAD_SYMBOL(ibv_poll_cq, poll_cq); } IBVWrapper& ibv() { static IBVWrapper wrapper; return wrapper; } SharedBuffer::SharedBuffer(size_t num_bytes) : data_(page_aligned_alloc(num_bytes)), num_bytes_(num_bytes) {} SharedBuffer::SharedBuffer(SharedBuffer&& b) : data_(nullptr), num_bytes_(0) { std::swap(data_, b.data_); std::swap(num_bytes_, b.num_bytes_); std::swap(memory_regions_, b.memory_regions_); } SharedBuffer::~SharedBuffer() { for (auto& [pd, mr] : memory_regions_) { ibv().dereg_mr(mr); } if (data_ != nullptr) { std::free(data_); } } void SharedBuffer::register_to_protection_domain(ibv_pd* protection_domain) { auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr}); if (!inserted) { throw std::runtime_error( "[jaccl] Buffer can be registered once per protection domain"); } it->second = ibv().reg_mr( protection_domain, data_, num_bytes_, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE); if (!it->second) { throw std::runtime_error("[jaccl] Register memory region failed"); } } Connection::Connection(ibv_context* ctx_) : ctx(ctx_), protection_domain(nullptr), completion_queue(nullptr), queue_pair(nullptr) { src.local_id = -1; } Connection::Connection(Connection&& c) : Connection(nullptr) { std::swap(ctx, c.ctx); std::swap(protection_domain, c.protection_domain); std::swap(completion_queue, c.completion_queue); std::swap(queue_pair, c.queue_pair); std::swap(src, c.src); } Connection::~Connection() { 
if (queue_pair != nullptr) { ibv().destroy_qp(queue_pair); } if (completion_queue != nullptr) { ibv().destroy_cq(completion_queue); } if (protection_domain != nullptr) { ibv().dealloc_pd(protection_domain); } if (ctx != nullptr) { ibv().close_device(ctx); } } void Connection::allocate_protection_domain() { protection_domain = ibv().alloc_pd(ctx); if (protection_domain == nullptr) { throw std::runtime_error("[jaccl] Couldn't allocate protection domain"); } } void Connection::create_completion_queue(int num_entries) { completion_queue = ibv().create_cq(ctx, num_entries, nullptr, nullptr, 0); if (completion_queue == nullptr) { throw std::runtime_error("[jaccl] Couldn't create completion queue"); } } void Connection::create_queue_pair() { ibv_qp_init_attr init_attr; init_attr.qp_context = ctx; init_attr.qp_context = ctx; init_attr.send_cq = completion_queue; init_attr.recv_cq = completion_queue; init_attr.srq = nullptr; init_attr.cap.max_send_wr = MAX_SEND_WR; init_attr.cap.max_recv_wr = MAX_RECV_WR; init_attr.cap.max_send_sge = 1; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_inline_data = 0; init_attr.qp_type = IBV_QPT_UC; init_attr.sq_sig_all = 0; queue_pair = ibv().create_qp(protection_domain, &init_attr); if (queue_pair == nullptr) { throw std::runtime_error("[jaccl] Couldn't create queue pair"); } } const Destination& Connection::info() { if (queue_pair == nullptr || src.local_id >= 0) { return src; } ibv_port_attr port_attr; ibv().query_port(ctx, 1, &port_attr); ibv_gid gid; ibv().query_gid(ctx, 1, 1, &gid); src.local_id = port_attr.lid; src.queue_pair_number = queue_pair->qp_num; src.packet_sequence_number = 7; // TODO: Change to sth random src.global_identifier = gid; return src; } void Connection::queue_pair_init() { ibv_qp_attr attr = {}; attr.qp_state = IBV_QPS_INIT; attr.port_num = 1; attr.pkey_index = 0; attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE; int mask = IBV_QP_STATE | IBV_QP_PKEY_INDEX | 
IBV_QP_PORT | IBV_QP_ACCESS_FLAGS; if (int status = ibv().modify_qp(queue_pair, &attr, mask); status != 0) { std::ostringstream msg; msg << "[jaccl] Changing queue pair to INIT failed with errno " << status; throw std::invalid_argument(msg.str()); } } void Connection::queue_pair_rtr(const Destination& dst) { ibv_qp_attr attr = {}; memset(&attr, 0, sizeof(attr)); attr.qp_state = IBV_QPS_RTR; attr.path_mtu = IBV_MTU_1024; attr.rq_psn = dst.packet_sequence_number; attr.dest_qp_num = dst.queue_pair_number; attr.ah_attr.dlid = dst.local_id; attr.ah_attr.sl = 0; attr.ah_attr.src_path_bits = 0; attr.ah_attr.port_num = 1; attr.ah_attr.is_global = 0; if (dst.global_identifier.global.interface_id) { attr.ah_attr.is_global = 1; attr.ah_attr.grh.hop_limit = 1; attr.ah_attr.grh.dgid = dst.global_identifier; attr.ah_attr.grh.sgid_index = 1; } int mask = IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN; if (int status = ibv().modify_qp(queue_pair, &attr, mask); status != 0) { std::ostringstream msg; msg << "[jaccl] Changing queue pair to RTR failed with errno " << status; throw std::invalid_argument(msg.str()); } } void Connection::queue_pair_rts() { ibv_qp_attr attr = {}; attr.qp_state = IBV_QPS_RTS; attr.sq_psn = src.packet_sequence_number; int mask = IBV_QP_STATE | IBV_QP_SQ_PSN; if (int status = ibv().modify_qp(queue_pair, &attr, mask); status != 0) { std::ostringstream msg; msg << "[jaccl] Changing queue pair to RTS failed with errno " << status; throw std::invalid_argument(msg.str()); } } std::vector create_connections( const std::vector& device_names) { std::vector connections; int num_devices = 0; ibv_device** devices = ibv().get_device_list(&num_devices); for (auto& name : device_names) { // Empty so add a nullptr context if (name.empty()) { connections.emplace_back(nullptr); continue; } // Search for the name and try to open the device for (int i = 0; i < num_devices; i++) { if (name == ibv().get_device_name(devices[i])) { auto ctx = 
ibv().open_device(devices[i]); if (ctx == nullptr) { std::ostringstream msg; msg << "[jaccl] Could not open device " << name; throw std::runtime_error(msg.str()); } connections.emplace_back(ctx); break; } } } ibv().free_device_list(devices); return connections; } SideChannel::SideChannel(int rank, int size, const char* addr) : rank_(rank), size_(size) { auto address = detail::parse_address(addr); if (rank_ == 0) { detail::TCPSocket server(IBV_TAG); server.listen(IBV_TAG, address); for (int i = 0; i < size - 1; i++) { sockets_.push_back(server.accept(IBV_TAG)); } std::vector ranks(size - 1); for (int i = 0; i < size - 1; i++) { sockets_[i].recv( IBV_TAG, reinterpret_cast(&ranks[i]), sizeof(int)); ranks[i]--; } for (int i = 0; i < size - 1; i++) { while (i != ranks[i]) { std::swap(sockets_[i], sockets_[ranks[i]]); std::swap(ranks[i], ranks[ranks[i]]); } } } else { sockets_.push_back( detail::TCPSocket::connect( IBV_TAG, address, 4, 1000, [](int attempt, int wait) { std::cerr << IBV_TAG << " Connection attempt " << attempt << " waiting " << wait << " ms" << std::endl; })); sockets_[0].send(IBV_TAG, reinterpret_cast(&rank_), sizeof(int)); } } SideChannel::SideChannel(SideChannel&& sc) : rank_(sc.rank_), size_(sc.size_), sockets_(std::move(sc.sockets_)) { sc.rank_ = -1; sc.size_ = -1; } } // namespace mlx::core::distributed::jaccl ================================================ FILE: mlx/distributed/jaccl/utils.h ================================================ // Copyright © 2025 Apple Inc. 
#pragma once

#include
#include
#include
#include

#include "mlx/distributed/utils.h"

// Tag passed to the side-channel socket helpers and printed in diagnostics.
constexpr const char* IBV_TAG = "[jaccl]";
// NOTE(review): SEND_WR / RECV_WR look like work-request id tags; the rest
// size the queue pairs (MAX_*_WR) and the staging buffers. Confirm at the
// call sites.
constexpr int SEND_WR = 1;
constexpr int RECV_WR = 2;
constexpr int MAX_SEND_WR = 32;
constexpr int MAX_RECV_WR = 32;
constexpr int BUFFER_SIZES = 8;
constexpr int NUM_BUFFERS = 2;
constexpr int FRAME_SIZE = 4096;

namespace detail = mlx::core::distributed::detail;

// NOTE(review): an anonymous namespace in a header gives every including
// translation unit its own copy of these entities.
namespace {

// Trait used by SideChannel::all_gather to tell containers (e.g. std::vector,
// std::string) from scalars.
template
struct is_container : std::false_type {};

template
struct is_container<
    T,
    std::void_t> : std::true_type {};

// Picks a staging-buffer size class for a message of `msg` bytes. Returns
// {k, FRAME_SIZE * 2^k} for the largest k in [1, BUFFER_SIZES) such that
// msg >= FRAME_SIZE * 2^k. On OS versions older than 26.3, or when no k
// matches, it returns {0, FRAME_SIZE}.
inline std::pair buffer_size_from_message(int64_t msg) {
  if (__builtin_available(macOS 26.3, iOS 26.3, tvOS 26.3, visionOS 26.3, *)) {
    for (int k = BUFFER_SIZES - 1; k > 0; k--) {
      if (msg >= FRAME_SIZE * (1 << k)) {
        return {k, FRAME_SIZE * (1 << k)};
      }
    }
  }
  return {0, FRAME_SIZE};
}

} // namespace

namespace mlx::core::distributed::jaccl {

/**
 * Wrapper for the ibverbs API.
 *
 * Holds function pointers into the ibverbs library that are populated by the
 * constructor. is_available() reports whether the library handle was set.
 */
struct IBVWrapper {
  IBVWrapper();
  bool is_available() {
    return librdma_handle_ != nullptr;
  }

  // API
  ibv_device** (*get_device_list)(int*);
  const char* (*get_device_name)(ibv_device*);
  ibv_context* (*open_device)(ibv_device*);
  void (*free_device_list)(ibv_device**);
  int (*close_device)(ibv_context*);

  ibv_pd* (*alloc_pd)(ibv_context*);
  ibv_qp* (*create_qp)(ibv_pd*, ibv_qp_init_attr*);
  ibv_cq* (*create_cq)(ibv_context*, int, void*, ibv_comp_channel*, int);
  int (*destroy_cq)(ibv_cq*);
  int (*destroy_qp)(ibv_qp*);
  int (*dealloc_pd)(ibv_pd*);

  int (*query_port)(ibv_context*, uint8_t, ibv_port_attr*);
  int (*query_gid)(ibv_context*, uint8_t, int, ibv_gid*);
  int (*modify_qp)(ibv_qp*, ibv_qp_attr*, int);
  ibv_mr* (*reg_mr)(ibv_pd*, void*, size_t, int);
  int (*dereg_mr)(ibv_mr*);

 private:
  void* librdma_handle_;
};

// Returns the process-wide IBVWrapper instance.
IBVWrapper& ibv();

/**
 * Contains the information that defines a destination to a remote device.
 * Basically we can compute our own destination and share it with remote hosts
 * over the side channel.
 */
struct Destination {
  int local_id;
  int queue_pair_number;
  int packet_sequence_number;
  ibv_gid global_identifier;
};

/**
 * A buffer that can be registered to a number of protection domains.
 *
 * Move-only. It keeps one memory region per registered protection domain;
 * local_key() and to_scatter_gather_entry() look that region up with .at(),
 * so they throw std::out_of_range for an unregistered protection domain.
 */
class SharedBuffer {
 public:
  SharedBuffer(size_t num_bytes);
  SharedBuffer(SharedBuffer&& b);
  ~SharedBuffer();

  SharedBuffer(const SharedBuffer&) = delete;
  SharedBuffer& operator=(const SharedBuffer&) = delete;

  void register_to_protection_domain(ibv_pd* protection_domain);

  size_t size() const {
    return num_bytes_;
  }

  uint32_t local_key(ibv_pd* protection_domain) const {
    return memory_regions_.at(protection_domain)->lkey;
  }

  // Describes the whole buffer as a single scatter/gather entry.
  ibv_sge to_scatter_gather_entry(ibv_pd* protection_domain) const {
    ibv_sge entry;
    entry.addr = reinterpret_cast(data_);
    entry.length = size();
    entry.lkey = local_key(protection_domain);
    return entry;
  }

  template
  T* data() {
    return static_cast(data_);
  }

  template
  T* begin() {
    return static_cast(data_);
  }

  // One past the last whole T that fits in the buffer.
  template
  T* end() {
    return static_cast(data_) + size() / sizeof(T);
  }

 private:
  void* data_;
  size_t num_bytes_;
  std::unordered_map memory_regions_;
};

/**
 * Manipulates an RDMA connection. Enables (among other things)
 *
 * - Creating a queue pair
 * - Sending and receiving
 * - Checking completion
 *
 * The fields are raw ibverbs handles; any that are non-null are released by
 * the destructor.
 */
struct Connection {
  ibv_context* ctx;
  ibv_pd* protection_domain;
  ibv_cq* completion_queue;
  ibv_qp* queue_pair;
  Destination src; // holds the local information

  Connection(ibv_context* ctx_);
  Connection(Connection&& c);
  Connection(const Connection&) = delete;
  // NOTE(review): conventionally `const Connection&`; the type is
  // non-assignable either way.
  Connection& operator=(Connection&) = delete;
  ~Connection();

  void allocate_protection_domain();
  void create_completion_queue(int num_entries);
  void create_queue_pair();
  const Destination& info();
  void queue_pair_init();
  void queue_pair_rtr(const Destination& dst);
  void queue_pair_rts();

  // Posts a signaled SEND of the whole buffer. Throws std::invalid_argument
  // if ibv_post_send reports an error.
  void post_send(const SharedBuffer& buff, uint64_t work_request_id) {
    // NOTE(review): work_request is not zero-initialized, so fields not
    // assigned below are indeterminate; consider `ibv_send_wr work_request =
    // {};`.
    ibv_send_wr work_request, *bad_work_request;

    auto entry = buff.to_scatter_gather_entry(protection_domain);
    work_request.wr_id = work_request_id;
    work_request.sg_list = &entry;
    work_request.num_sge = 1;
    work_request.opcode = IBV_WR_SEND;
    work_request.send_flags = IBV_SEND_SIGNALED;
    work_request.next = nullptr;

    if (int status =
            ibv_post_send(queue_pair, &work_request, &bad_work_request);
        status != 0) {
      std::ostringstream msg;
      msg << "[jaccl] Send failed with error code " << status;
      throw std::invalid_argument(msg.str());
    }
  }

  // Posts a receive into the whole buffer. Throws std::invalid_argument if
  // ibv_post_recv reports an error.
  void post_recv(const SharedBuffer& buff, uint64_t work_request_id) {
    // NOTE(review): same as post_send; work_request is not zero-initialized.
    ibv_recv_wr work_request, *bad_work_request;

    auto entry = buff.to_scatter_gather_entry(protection_domain);
    work_request.wr_id = work_request_id;
    work_request.sg_list = &entry;
    work_request.num_sge = 1;
    work_request.next = nullptr;

    if (int status =
            ibv_post_recv(queue_pair, &work_request, &bad_work_request);
        status != 0) {
      std::ostringstream msg;
      msg << "[jaccl] Recv failed with error code " << status;
      throw std::invalid_argument(msg.str());
    }
  }

  // Thin wrapper over ibv_poll_cq on this connection's completion queue.
  int poll(int num_completions, ibv_wc* work_completions) {
    return ibv_poll_cq(completion_queue, num_completions, work_completions);
  }
};

std::vector create_connections(
    const std::vector& device_names);

// Polls the completion queue of every connection with a non-null context,
// appending into work_completions until num_completions entries are filled.
// Returns the number of completions gathered.
// NOTE(review): per the ibverbs docs ibv_poll_cq returns a negative value on
// error. It is added to `completions` unchecked, which would undercount.
inline int poll(
    std::span connections,
    int num_completions,
    ibv_wc* work_completions) {
  int completions = 0;
  for (auto& c : connections) {
    if (c.ctx == nullptr) {
      continue;
    }
    if (completions >= num_completions) {
      return completions;
    }

    int n = ibv_poll_cq(
        c.completion_queue,
        num_completions - completions,
        work_completions + completions);

    completions += n;
  }
  return completions;
}

// Polls connections_1 first, then fills any remaining capacity from
// connections_2.
inline int poll(
    std::span connections_1,
    std::span connections_2,
    int num_completions,
    ibv_wc* work_completions) {
  int completions = 0;
  completions += poll(connections_1, num_completions, work_completions);
  completions += poll(
      connections_2,
      num_completions - completions,
      work_completions + completions);
  return completions;
}

/**
 * Implement a TCP side channel to exchange information about the RDMA
 * connections.
 *
 * Implements a simple all gather where every node sends to rank 0 and rank 0
 * broadcasts to every node.
 */
class SideChannel {
 public:
  SideChannel(int rank, int size, const char* addr);
  SideChannel(SideChannel&& sc);

  SideChannel(const SideChannel&) = delete;
  SideChannel& operator=(const SideChannel&) = delete;

  // Gathers `v` from every rank; entry i of the result is rank i's value.
  // For containers the lengths are gathered first. Every payload is padded
  // to the longest length on the wire and trimmed back afterwards.
  // NOTE(review): the scalar path uses result.data(), so T = bool would not
  // compile (std::vector<bool> has no data()).
  template
  std::vector all_gather(const T& v) {
    std::vector result(size_);

    // T is a container of stuff like std::vector or std::string
    if constexpr (is_container::value) {
      using U = typename T::value_type;

      // Share the lengths first and set the communication size to be the
      // maximum length of the containers.
      auto lengths = all_gather(v.size());
      auto max_len = *std::max_element(lengths.begin(), lengths.end());
      for (auto& s : result) {
        s.resize(max_len);
      }

      // All gather of length max_len
      if (rank_ == 0) {
        std::copy(v.begin(), v.end(), result[rank_].begin());
        for (int i = 1; i < size_; i++) {
          sockets_[i - 1].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);
        }
        for (int i = 1; i < size_; i++) {
          for (int j = 0; j < size_; j++) {
            sockets_[i - 1].send(
                IBV_TAG, result[j].data(), sizeof(U) * max_len);
          }
        }
      } else {
        std::copy(v.begin(), v.end(), result[rank_].begin());
        sockets_[0].send(IBV_TAG, result[rank_].data(), sizeof(U) * max_len);
        for (int i = 0; i < size_; i++) {
          sockets_[0].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);
        }
      }

      // Resize the outputs back to the original length
      for (int i = 0; i < size_; i++) {
        result[i].resize(lengths[i]);
      }
    }

    // T is a scalar
    else {
      if (rank_ == 0) {
        result[rank_] = v;
        for (int i = 1; i < size_; i++) {
          sockets_[i - 1].recv(IBV_TAG, &result[i], sizeof(T));
        }
        for (int i = 1; i < size_; i++) {
          sockets_[i - 1].send(IBV_TAG, result.data(), size_ * sizeof(T));
        }
      } else {
        sockets_[0].send(IBV_TAG, &v, sizeof(T));
        sockets_[0].recv(IBV_TAG, result.data(), size_ * sizeof(T));
      }
    }

    return result;
  }

 private:
  int rank_;
  int size_;
  std::vector sockets_;
};

} // namespace mlx::core::distributed::jaccl

================================================
FILE: mlx/distributed/mpi/CMakeLists.txt
================================================
if(MLX_BUILD_CPU)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)
endif()

================================================
FILE: mlx/distributed/mpi/mpi.cpp
================================================
// Copyright © 2024 Apple Inc.
#include
#include
#include

#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/mpi/mpi_declarations.h"

// Resolves `symbol` from the loaded MPI library into `variable`. On failure it
// clears libmpi_handle_ (so is_available() reports false) and returns from
// the enclosing function.
// NOTE(review): dlerror() is not cleared before dlsym, so a stale error from
// an earlier call could be misreported; the handle is also never dlclose'd.
#define LOAD_SYMBOL(symbol, variable)                              \
  {                                                                \
    variable = (decltype(variable))dlsym(libmpi_handle_, #symbol); \
    char* error = dlerror();                                       \
    if (error != nullptr) {                                        \
      libmpi_handle_ = nullptr;                                    \
      return;                                                      \
    }                                                              \
  }

// Library to dlopen: MLX_MPI_LIBNAME if set, otherwise the platform default.
static const char* get_libmpi_name() {
  const char* libname = std::getenv("MLX_MPI_LIBNAME");
  if (libname != nullptr) {
    return libname;
  }
#ifdef __APPLE__
  return "libmpi.dylib";
#else
  return "libmpi.so";
#endif
}

namespace mlx::core::distributed::mpi {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

namespace {

// Element-wise reduction callbacks with the MPI_User_function signature.
// They are registered in init_safe() for the dtypes whose custom ops are
// created there (float16, bfloat16, complex64).
template
void simple_sum(
    void* input,
    void* accumulator,
    int* len,
    MPI_Datatype* datatype) {
  T* in = (T*)input;
  T* acc = (T*)accumulator;
  int N = *len;

  while (N-- > 0) {
    *acc += *in;
    acc++;
    in++;
  }
}
template void simple_sum(void*, void*, int*, MPI_Datatype*);
template void simple_sum(void*, void*, int*, MPI_Datatype*);

template
void simple_max(
    void* input,
    void* accumulator,
    int* len,
    MPI_Datatype* datatype) {
  T* in = (T*)input;
  T* acc = (T*)accumulator;
  int N = *len;

  while (N-- > 0) {
    *acc = std::max(*acc, *in);
    acc++;
    in++;
  }
}
template void simple_max(void*, void*, int*, MPI_Datatype*);
template void simple_max(void*, void*, int*, MPI_Datatype*);
template void simple_max(void*, void*, int*, MPI_Datatype*);

template
void simple_min(
    void* input,
    void* accumulator,
    int* len,
    MPI_Datatype* datatype) {
  T* in = (T*)input;
  T* acc = (T*)accumulator;
  int N = *len;

  while (N-- > 0) {
    *acc = std::min(*acc, *in);
    acc++;
    in++;
  }
}
template void simple_min(void*, void*, int*, MPI_Datatype*);
template void simple_min(void*, void*, int*, MPI_Datatype*);
template void simple_min(void*, void*, int*, MPI_Datatype*);

// Loads Open MPI at runtime with dlopen and resolves the functions, global
// objects (communicator, ops, datatypes) that the MPI backend uses. If the
// library is missing, is not Open MPI, or lacks a symbol, libmpi_handle_ is
// left null and is_available() returns false.
struct MPIWrapper {
  MPIWrapper() {
    initialized_ = false;

    libmpi_handle_ = dlopen(get_libmpi_name(), RTLD_NOW | RTLD_GLOBAL);
    if (libmpi_handle_ == nullptr) {
      return;
    }

    // Check library version and warn if it isn't Open MPI
    int (*get_version)(char*, int*);
    LOAD_SYMBOL(MPI_Get_library_version, get_version);
    char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING];
    int version_length = 0;
    get_version(version_ptr, &version_length);
    std::string_view version(version_ptr, version_length);
    if (version.find("Open MPI") == std::string::npos) {
      // NOTE(review): the two literals below are printed with no space
      // between them ("...Open MPI.MLX requires...").
      std::cerr << "[mpi] MPI found but it does not appear to be Open MPI."
                << "MLX requires Open MPI but this is " << version << std::endl;
      libmpi_handle_ = nullptr;
      return;
    }

    // API
    LOAD_SYMBOL(MPI_Init, init);
    LOAD_SYMBOL(MPI_Finalize, finalize);
    LOAD_SYMBOL(MPI_Comm_rank, rank);
    LOAD_SYMBOL(MPI_Comm_size, size);
    LOAD_SYMBOL(MPI_Comm_split, comm_split);
    LOAD_SYMBOL(MPI_Comm_free, comm_free);
    LOAD_SYMBOL(MPI_Allreduce, all_reduce);
    LOAD_SYMBOL(MPI_Allgather, all_gather);
    LOAD_SYMBOL(MPI_Send, send);
    LOAD_SYMBOL(MPI_Recv, recv);
    LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
    LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
    LOAD_SYMBOL(MPI_Op_create, mpi_op_create);

    // Objects
    LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);

    // Ops
    LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);
    LOAD_SYMBOL(ompi_mpi_op_max, op_max_);
    LOAD_SYMBOL(ompi_mpi_op_min, op_min_);

    // Datatypes
    LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);
    LOAD_SYMBOL(ompi_mpi_int8_t, mpi_int8_);
    LOAD_SYMBOL(ompi_mpi_uint8_t, mpi_uint8_);
    LOAD_SYMBOL(ompi_mpi_int16_t, mpi_int16_);
    LOAD_SYMBOL(ompi_mpi_uint16_t, mpi_uint16_);
    LOAD_SYMBOL(ompi_mpi_int32_t, mpi_int32_);
    LOAD_SYMBOL(ompi_mpi_uint32_t, mpi_uint32_);
    LOAD_SYMBOL(ompi_mpi_int64_t, mpi_int64_);
    LOAD_SYMBOL(ompi_mpi_uint64_t, mpi_uint64_);
    LOAD_SYMBOL(ompi_mpi_float, mpi_float_);
    LOAD_SYMBOL(ompi_mpi_double, mpi_double_);
    LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_);
  }

  bool is_available() {
    return libmpi_handle_ != nullptr;
  }

  // Calls MPI_Init and returns whether it succeeded. After the first
  // successful call it also creates the 2-byte float16/bfloat16 datatypes
  // and the custom reduction ops.
  bool init_safe() {
    if (!is_available()) {
      return false;
    }
    bool success = init(nullptr, nullptr) == MPI_SUCCESS;

    // Initialize custom types and ops
    if (success && !initialized_) {
      // Custom float16 dtypes
      mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);
      mpi_type_commit(&mpi_float16_);
      mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
      mpi_type_commit(&mpi_bfloat16_);

      // Custom reduction ops
      mpi_op_create(&simple_sum, 1, &op_sum_f16_);
      mpi_op_create(&simple_sum, 1, &op_sum_bf16_);
      mpi_op_create(&simple_max, 1, &op_max_f16_);
      mpi_op_create(&simple_max, 1, &op_max_bf16_);
      mpi_op_create(&simple_max, 1, &op_max_c64_);
      mpi_op_create(&simple_min, 1, &op_min_f16_);
      mpi_op_create(&simple_min, 1, &op_min_bf16_);
      mpi_op_create(&simple_min, 1, &op_min_c64_);

      initialized_ = true;
    }

    return success;
  }

  void finalize_safe() {
    if (is_available()) {
      finalize();
    }
  }

  MPI_Comm world() {
    return comm_world_;
  }

  // Maps an array dtype to its MPI datatype; throws std::runtime_error for
  // unhandled dtypes.
  MPI_Datatype datatype(const array& arr) {
    switch (arr.dtype()) {
      case bool_:
        return mpi_bool_;
      case int8:
        return mpi_int8_;
      case uint8:
        return mpi_uint8_;
      case int16:
        return mpi_int16_;
      case uint16:
        return mpi_uint16_;
      case int32:
        return mpi_int32_;
      case uint32:
        return mpi_uint32_;
      case int64:
        return mpi_int64_;
      case uint64:
        return mpi_uint64_;
      case float32:
        return mpi_float_;
      case complex64:
        return mpi_complex_;
      case float16:
        return mpi_float16_;
      case bfloat16:
        return mpi_bfloat16_;
      case float64:
        return mpi_double_;
      default:
        throw std::runtime_error("Invalid type");
    }
  }

  // The op_* selectors use the custom ops for dtypes that have one and fall
  // back to the built-in MPI op otherwise.
  MPI_Op op_sum(const array& arr) {
    switch (arr.dtype()) {
      case float16:
        return op_sum_f16_;
      case bfloat16:
        return op_sum_bf16_;
      default:
        return op_sum_;
    }
  }

  MPI_Op op_max(const array& arr) {
    switch (arr.dtype()) {
      case float16:
        return op_max_f16_;
      case bfloat16:
        return op_max_bf16_;
      case complex64:
        return op_max_c64_;
      default:
        return op_max_;
    }
  }

  MPI_Op op_min(const array& arr) {
    switch (arr.dtype()) {
      case float16:
        return op_min_f16_;
      case bfloat16:
        return op_min_bf16_;
      case complex64:
        return op_min_c64_;
      default:
        return op_min_;
    }
  }

  void* libmpi_handle_;

  // API
  int (*init)(int*, char***);
  int (*finalize)();
  int (*rank)(MPI_Comm, int*);
  int (*size)(MPI_Comm, int*);
  int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
  int (*all_gather)(
      const void*,
      int,
      MPI_Datatype,
      void*,
      int,
      MPI_Datatype,
      MPI_Comm);
  int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
  int (*comm_free)(MPI_Comm*);
  int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
  int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);

  // Objects
  MPI_Comm comm_world_;

  // Ops
  MPI_Op op_sum_;
  MPI_Op op_sum_f16_;
  MPI_Op op_sum_bf16_;
  MPI_Op op_max_;
  MPI_Op op_max_f16_;
  MPI_Op op_max_bf16_;
  MPI_Op op_max_c64_;
  MPI_Op op_min_;
  MPI_Op op_min_f16_;
  MPI_Op op_min_bf16_;
  MPI_Op op_min_c64_;

  // Datatypes
  MPI_Datatype mpi_bool_;
  MPI_Datatype mpi_int8_;
  MPI_Datatype mpi_uint8_;
  MPI_Datatype mpi_int16_;
  MPI_Datatype mpi_uint16_;
  MPI_Datatype mpi_int32_;
  MPI_Datatype mpi_uint32_;
  MPI_Datatype mpi_int64_;
  MPI_Datatype mpi_uint64_;
  MPI_Datatype mpi_float_;
  MPI_Datatype mpi_double_;
  MPI_Datatype mpi_complex_;
  MPI_Datatype mpi_float16_;
  MPI_Datatype mpi_bfloat16_;

 private:
  bool initialized_;

  // Private API
  int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);
  int (*mpi_type_commit)(MPI_Datatype*);
  int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);
};

// Process-wide wrapper instance, constructed on first use.
MPIWrapper& mpi() {
  static MPIWrapper wrapper;
  return wrapper;
}

} // namespace

// GroupImpl backed by an MPI communicator. Collectives are dispatched on the
// CPU command encoder of the given stream.
class MPIGroup : public GroupImpl {
 public:
  MPIGroup(MPI_Comm comm, bool global)
      : comm_(comm), global_(global), rank_(-1), size_(-1) {}

  // The global group finalizes MPI; groups created by split() free their
  // communicator.
  virtual ~MPIGroup() {
    if (global_) {
      mpi().finalize_safe();
    } else {
      mpi().comm_free(&comm_);
    }
  }

  Stream communication_stream(StreamOrDevice s) override {
    return to_stream(s, Device::cpu);
  }

  // rank() and size() query MPI once and cache the result.
  int rank() override {
    if (rank_ < 0) {
      mpi().rank(comm_, &rank_);
    }
    return rank_;
  }

  int size() override {
    if (size_ < 0) {
      mpi().size(comm_, &size_);
    }
    return size_;
  }

  // Splits by `color`; ordering within the new group uses `key`, defaulting
  // to this process's rank.
  std::shared_ptr split(int color, int key = -1) override {
    key = (key < 0) ? rank() : key;

    MPI_Comm new_comm;
    int result = mpi().comm_split(comm_, color, key, &new_comm);
    if (result != MPI_SUCCESS) {
      throw std::runtime_error("MPI could not split this group");
    }

    return std::make_shared(new_comm, false);
  }

  // The all_* reductions pass MPI_IN_PLACE when input and output share a
  // buffer.
  void all_sum(const array& input, array& output, Stream stream) override {
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_input_array(input);
    encoder.set_output_array(output);
    encoder.dispatch(
        mpi().all_reduce,
        (input.data() == output.data()) ? MPI_IN_PLACE
                                        : input.data(),
        output.data(),
        input.size(),
        mpi().datatype(input),
        mpi().op_sum(input),
        comm_);
  }

  void all_max(const array& input, array& output, Stream stream) override {
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_input_array(input);
    encoder.set_output_array(output);
    encoder.dispatch(
        mpi().all_reduce,
        (input.data() == output.data()) ? MPI_IN_PLACE
                                        : input.data(),
        output.data(),
        input.size(),
        mpi().datatype(input),
        mpi().op_max(input),
        comm_);
  }

  void all_min(const array& input, array& output, Stream stream) override {
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_input_array(input);
    encoder.set_output_array(output);
    encoder.dispatch(
        mpi().all_reduce,
        (input.data() == output.data()) ? MPI_IN_PLACE
                                        : input.data(),
        output.data(),
        input.size(),
        mpi().datatype(input),
        mpi().op_min(input),
        comm_);
  }

  // The receive count is per process, so it equals input.size().
  void all_gather(const array& input, array& output, Stream stream) override {
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_input_array(input);
    encoder.set_output_array(output);
    encoder.dispatch(
        mpi().all_gather,
        input.data(),
        input.size(),
        mpi().datatype(input),
        output.data(),
        input.size(),
        mpi().datatype(output),
        comm_);
  }

  // Sends with tag 0.
  void send(const array& input, int dst, Stream stream) override {
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_input_array(input);
    encoder.dispatch(
        mpi().send,
        input.data(),
        input.size(),
        mpi().datatype(input),
        dst,
        0,
        comm_);
  }

  // Receives from `src` with any tag. The lambda captures plain values
  // (pointer, size, datatype, communicator) rather than `out` or `this`.
  void recv(array& out, int src, Stream stream) override {
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_output_array(out);
    encoder.dispatch([out_ptr = out.data(),
                      out_size = out.size(),
                      out_type = mpi().datatype(out),
                      src,
                      comm = comm_]() {
      MPI_Status status;
      mpi().recv(out_ptr, out_size, out_type, src, MPI_ANY_TAG, comm, &status);
    });
  }

  void sum_scatter(const array& input, array& output, Stream stream) override {
    throw std::runtime_error("[mpi] sum_scatter not yet implemented.");
  }

 private:
  MPI_Comm comm_;
  bool global_;
  int rank_;
  int size_;
};

bool is_available() {
  return mpi().is_available();
}

// Initializes MPI and wraps the world communicator. On failure it returns
// nullptr, or throws std::runtime_error when `strict` is true.
std::shared_ptr init(bool strict /* = false */) {
  if (!mpi().init_safe()) {
    if (strict) {
      throw std::runtime_error("Cannot initialize MPI");
    }
    return nullptr;
  }
  return std::make_shared(mpi().world(), true);
}

} // namespace mlx::core::distributed::mpi

================================================
FILE: mlx/distributed/mpi/mpi.h
================================================
// Copyright © 2024 Apple Inc.
#pragma once

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed::mpi {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

// Whether the MPI backend can be used. The no-MPI build always returns false;
// the MPI build returns false when the library could not be loaded.
bool is_available();

// Initializes MPI and returns a group wrapping the world communicator.
// On failure it returns nullptr, or throws std::runtime_error if `strict`.
std::shared_ptr<GroupImpl> init(bool strict = false);

} // namespace mlx::core::distributed::mpi

================================================
FILE: mlx/distributed/mpi/mpi_declarations.h
================================================
// Copyright © 2024 Apple Inc.

// Guard against double inclusion: the MPI_Status struct definition below may
// appear only once per translation unit.
#pragma once

#include <cstddef> // size_t in MPI_Status

// Constants
#define MPI_SUCCESS 0
#define MPI_ANY_SOURCE -1
#define MPI_ANY_TAG -1
#define MPI_IN_PLACE ((void*)1)
#define MPI_MAX_LIBRARY_VERSION_STRING 256

// Define all the types that we use so that we don't include mpi.h, which
// causes linker errors on some platforms.
//
// NOTE: We define everything for openmpi.
typedef void* MPI_Comm;
typedef void* MPI_Datatype;
typedef void* MPI_Op;
typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);
typedef struct ompi_status_public_t {
  int MPI_SOURCE;
  int MPI_TAG;
  int MPI_ERROR;
  int _cancelled;
  size_t _ucount;
} MPI_Status;

================================================
FILE: mlx/distributed/mpi/no_mpi.cpp
================================================
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/mpi/mpi.h" namespace mlx::core::distributed::mpi { using GroupImpl = mlx::core::distributed::detail::GroupImpl; bool is_available() { return false; } std::shared_ptr init(bool strict /* = false */) { if (strict) { throw std::runtime_error("Cannot initialize MPI"); } return nullptr; } } // namespace mlx::core::distributed::mpi ================================================ FILE: mlx/distributed/nccl/CMakeLists.txt ================================================ if(MLX_BUILD_CUDA AND NOT WIN32) find_package(NCCL) if(NCCL_FOUND) target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES}) target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS}) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp) endif() else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp) endif() ================================================ FILE: mlx/distributed/nccl/nccl.cpp ================================================ // NCCL distributed support currently requires Unix socket APIs // TODO: Add Windows Winsock2 support for Windows builds #ifndef _WIN32 #include #include #include #include #endif #include #include #include #include #include #include #include #include #include #include #include "mlx/backend/cuda/device.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/dtype_utils.h" #include "mlx/utils.h" namespace mlx::core::distributed::nccl { // Can be tuned with MLX_NCCL_TIMEOUT constexpr int nccl_timeout = 300000; // miliseconds #define CHECK_CUDA(cmd) \ do { \ cudaError_t e = cmd; \ if (e != cudaSuccess) { \ fprintf( \ stderr, \ "CUDA error %s:%d '%s'\n", \ __FILE__, \ __LINE__, \ cudaGetErrorString(e)); \ exit(1); \ } \ } while (0) #define CHECK_NCCL(cmd) \ do { \ ncclResult_t r = cmd; \ if (r != ncclSuccess) { \ fprintf( \ stderr, \ "NCCL error %s:%d '%s'\n", \ __FILE__, \ __LINE__, \ 
ncclGetErrorString(r)); \ exit(1); \ } \ } while (0) #define MLX_NCCL_TYPE_LIST(X) \ X(int8_t, ncclChar) \ X(uint8_t, ncclUint8) \ X(int32_t, ncclInt) \ X(uint32_t, ncclUint32) \ X(int64_t, ncclInt64) \ X(uint64_t, ncclUint64) \ X(float16_t, ncclHalf) \ X(bfloat16_t, ncclBfloat16) \ X(float, ncclFloat) \ X(double, ncclDouble) template struct nccl_map { static constexpr bool ok = false; // default: unsupported }; #define MLX_DEF_NCCL_MAP(T, E) \ template <> \ struct nccl_map { \ static constexpr bool ok = true; \ static constexpr ncclDataType_t value = E; \ }; MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP) #undef MLX_DEF_NCCL_MAP namespace detail { template void dispatch_dtype(const array& arr, F&& f) { dispatch_all_types(arr.dtype(), [&](auto type_tag) { using T = MLX_GET_TYPE(type_tag); if constexpr (nccl_map::ok) { f(type_tag, nccl_map::value); } else { throw std::invalid_argument("[nccl] Unknown or unsupported dtype"); } }); } #ifndef _WIN32 inline void sendAll(int sock, const void* buf, size_t len) { const char* ptr = reinterpret_cast(buf); while (len > 0) { ssize_t sent = send(sock, ptr, len, 0); if (sent <= 0) { perror("send"); exit(1); } ptr += sent; len -= sent; } } inline void recvAll(int sock, void* buf, size_t len) { char* ptr = reinterpret_cast(buf); while (len > 0) { ssize_t rec = recv(sock, ptr, len, 0); if (rec <= 0) { perror("recv"); exit(1); } ptr += rec; len -= rec; } } #endif // _WIN32 #ifndef _WIN32 inline void bootstrap_unique_id( ncclUniqueId& id, int rank, int size, const std::string& initMethod) { // Parse the init method to extract the host and port if (initMethod.rfind("tcp://", 0) != 0) throw; auto hostport = initMethod.substr(6); auto colon = hostport.find(':'); std::string host = hostport.substr(0, colon); int port = std::stoi(hostport.substr(colon + 1)); if (rank == 0) { // create a unique id on the rank 0 CHECK_NCCL(ncclGetUniqueId(&id)); // create a socket to send the unique id to all other ranks int sock = socket(AF_INET, SOCK_STREAM, 0); if 
(sock < 0) { std::ostringstream msg; msg << "[nccl] Couldn't create socket (error: " << errno << ")"; throw std::runtime_error(msg.str()); } sockaddr_in serv = {}; serv.sin_family = AF_INET; serv.sin_addr.s_addr = htonl(INADDR_ANY); serv.sin_port = htons(port); int reuse = 1; // Without this, if rank-0 crashes or restarts process quickly, // the OS might refuse to let binding to the same port, so reuse if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) { std::ostringstream msg; msg << "[nccl] setsockopt() failed: " << strerror(errno); throw std::runtime_error(msg.str()); } if (bind(sock, reinterpret_cast(&serv), sizeof(serv)) < 0) { std::ostringstream msg; msg << "[nccl] bind() failed: " << strerror(errno); throw std::runtime_error(msg.str()); } if (listen(sock, size - 1) < 0) { std::ostringstream msg; msg << "[nccl] listen() failed: " << strerror(errno); throw std::runtime_error(msg.str()); } for (int peer = 1; peer < size; ++peer) { int conn = accept(sock, nullptr, nullptr); if (conn < 0) { std::ostringstream msg; msg << "[nccl] accept() failed: " << strerror(errno); throw std::runtime_error(msg.str()); } sendAll(conn, &id, sizeof(id)); close(conn); } close(sock); } else { // Here we want to make sure that rank 0 has enough time to bind // so we will retry to connect until elapsed time exceeds nccl_timeout // this is particularity important for multinode setup int sock = socket(AF_INET, SOCK_STREAM, 0); if (sock < 0) { std::ostringstream msg; msg << "[nccl] socket() failed: " << strerror(errno); throw std::runtime_error(msg.str()); } hostent* he = gethostbyname(host.c_str()); if (!he) { throw std::runtime_error("[nccl] lookup failed for host: " + host); } sockaddr_in serv = {}; serv.sin_family = AF_INET; memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length); serv.sin_port = htons(port); const int timeout_ms = env::nccl_timeout(nccl_timeout); bool connected = false; const char* dbg = std::getenv("NCCL_DEBUG"); bool do_log = (dbg && 
std::string(dbg) == "INFO"); auto start = std::chrono::steady_clock::now(); int attempt = 0; while (true) { auto elapsed_ms = std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); if (elapsed_ms > timeout_ms) break; if (connect(sock, reinterpret_cast(&serv), sizeof(serv)) == 0) { connected = true; if (do_log) { std::cout << "[Rank " << rank << "] Connected successfully after " << elapsed_ms << " miliseconds" << std::endl; break; } } if (errno != ECONNREFUSED) { break; } ++attempt; std::this_thread::sleep_for(std::chrono::milliseconds(500)); } if (!connected) { std::ostringstream msg; msg << "[Rank " << rank << "] connect() failed after " << timeout_ms << " milliseconds and " << attempt << " retries: " << strerror(errno); close(sock); throw std::runtime_error(msg.str()); } recvAll(sock, &id, sizeof(id)); close(sock); } } #else // _WIN32 inline void bootstrap_unique_id( ncclUniqueId& id, int rank, int size, const std::string& initMethod) { throw std::runtime_error( "[nccl] Distributed NCCL is not yet supported on Windows"); } #endif // _WIN32 } // namespace detail // helper struct to manage communicator struct NCCLComm { ncclComm_t comm; int rank_; int size_; NCCLComm(ncclComm_t c, int rank, int size) : comm(c), rank_(rank), size_(size) {} static std::shared_ptr create(int numRanks, int rank, ncclUniqueId commId) { ncclComm_t raw; CHECK_NCCL(ncclCommInitRank(&raw, numRanks, commId, rank)); return std::make_shared(raw, rank, numRanks); } static std::shared_ptr split(NCCLComm* source, int color, int key) { ncclComm_t raw; // default config, blocking comm creation ncclConfig_t config = NCCL_CONFIG_INITIALIZER; CHECK_NCCL(ncclCommSplit(source->comm, color, key, &raw, &config)); int new_rank, new_size; CHECK_NCCL(ncclCommUserRank(raw, &new_rank)); CHECK_NCCL(ncclCommCount(raw, &new_size)); return std::make_shared(raw, new_rank, new_size); } NCCLComm(const NCCLComm&) = delete; NCCLComm& operator=(const NCCLComm&) = delete; }; using GroupImpl = 
mlx::core::distributed::detail::GroupImpl; class NCCLGroup : public GroupImpl { public: NCCLGroup(int worldRank, int worldSize, const std::string initMethod) : rank_(worldRank), size_(worldSize), initMethod_(initMethod) { if (initialized_) return; int ndev; CHECK_CUDA(cudaGetDeviceCount(&ndev)); CHECK_CUDA(cudaSetDevice(rank_ % ndev)); detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_); comm_ = NCCLComm::create(size_, rank_, uniqueId_); initialized_ = true; } // Used by split() to wrap an already-created communicator NCCLGroup(std::shared_ptr comm, int rank, int size) : rank_(rank), size_(size), comm_(std::move(comm)) {} Stream communication_stream(StreamOrDevice s) override { return to_stream(s, Device::gpu); } int rank() override { return rank_; } int size() override { return size_; } void all_sum(const array& input, array& output, Stream stream) override { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { using T = typename decltype(type_tag)::type; all_reduce_impl(input, output, stream, dt, ncclSum); }); } std::shared_ptr split(int color, int key = -1) override { key = (key < 0) ? 
rank() : key; auto new_comm = NCCLComm::split(comm_.get(), color, key); return std::make_shared( new_comm, new_comm->rank_, new_comm->size_); } void all_gather(const array& input, array& output, Stream stream) override { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { using T = typename decltype(type_tag)::type; auto& encoder = cu::get_command_encoder(stream); CHECK_NCCL(ncclAllGather( gpu_ptr(input), gpu_ptr(output), input.size(), dt, comm_->comm, encoder.stream())); }); } void send(const array& input, int dst, Stream stream) override { throw std::runtime_error("[nccl] Send not supported in NCCL backend."); } void recv(array& output, int src, Stream stream) override { throw std::runtime_error("[nccl] Recv not supported in NCCL backend."); } void all_max(const array& input, array& output, Stream stream) override { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { using T = typename decltype(type_tag)::type; all_reduce_impl(input, output, stream, dt, ncclMax); }); } void all_min(const array& input, array& output, Stream stream) override { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { using T = typename decltype(type_tag)::type; all_reduce_impl(input, output, stream, dt, ncclMin); }); } void sum_scatter(const array& input, array& output, Stream stream) override { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { using T = typename decltype(type_tag)::type; reduce_scatter_impl(input, output, stream, dt, ncclSum); }); } template void all_reduce_impl( const array& input, array& output, Stream stream, ncclDataType_t dt, ncclRedOp_t op) { auto& encoder = cu::get_command_encoder(stream); CHECK_NCCL(ncclAllReduce( gpu_ptr(input), gpu_ptr(output), input.size(), dt, op, comm_->comm, encoder.stream())); } template void reduce_scatter_impl( const array& input, array& output, Stream stream, ncclDataType_t dt, ncclRedOp_t op) { auto& encoder = cu::get_command_encoder(stream); 
CHECK_NCCL(ncclReduceScatter( gpu_ptr(input), gpu_ptr(output), output.size(), dt, op, comm_->comm, encoder.stream())); } int rank_; int size_; std::string initMethod_; ncclUniqueId uniqueId_; std::shared_ptr comm_; bool initialized_ = false; }; bool is_available() { return true; } namespace detail { std::string get_env_var_or_throw(const char* env_var_name, bool strict) { const char* value = std::getenv(env_var_name); if (value == nullptr && strict) { std::ostringstream msg; msg << "[nccl] Required environment variable '" << env_var_name << "' is not set. " << "Please set it before initializing the distributed backend."; throw std::runtime_error(msg.str()); } if (value == nullptr) { return ""; } return std::string(value); } } // namespace detail std::shared_ptr init(bool strict /* = false */) { std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP", strict); std::string port = detail::get_env_var_or_throw("NCCL_PORT", strict); std::string rank_str = detail::get_env_var_or_throw("MLX_RANK", strict); std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE", strict); if (!strict && (host.empty() || port.empty() || rank_str.empty() || n_nodes_str.empty())) { return nullptr; } int rank = std::stoi(rank_str); int n_nodes = std::stoi(n_nodes_str); std::string init_method = "tcp://" + host + ":" + port; return std::make_shared(rank, n_nodes, init_method); } } // namespace mlx::core::distributed::nccl ================================================ FILE: mlx/distributed/nccl/nccl.h ================================================ // Copyright © 2024 Apple Inc. 
#pragma once

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed::nccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

// True when the NCCL backend was compiled in (nccl.cpp), false otherwise
// (no_nccl.cpp).
bool is_available();

// Create the NCCL group for this process. Returns nullptr when the backend
// cannot be initialized and `strict` is false; throws when `strict` is true.
std::shared_ptr<GroupImpl> init(bool strict = false);

} // namespace mlx::core::distributed::nccl

================================================
FILE: mlx/distributed/nccl/no_nccl.cpp
================================================
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/nccl/nccl.h"

namespace mlx::core::distributed::nccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available() {
  return false;
}

std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  if (strict) {
    throw std::runtime_error("Cannot initialize nccl distributed backend.");
  }
  return nullptr;
}

} // namespace mlx::core::distributed::nccl

================================================
FILE: mlx/distributed/ops.cpp
================================================
// Copyright © 2024 Apple Inc.

#include <sstream>

#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"

namespace mlx::core::distributed {

namespace {

// Use the explicitly provided group, falling back to distributed::init().
Group to_group(std::optional<Group> group) {
  if (group.has_value()) {
    return group.value();
  } else {
    return distributed::init();
  }
}

} // namespace

// Element-wise sum across the group; identity for a singleton group.
array all_sum(
    const array& x,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    return x;
  }
  auto stream = detail::communication_stream(group, s);

  return array(
      x.shape(),
      x.dtype(),
      std::make_shared<AllReduce>(stream, group, AllReduce::Sum),
      {x});
}

// Element-wise maximum across the group; identity for a singleton group.
array all_max(
    const array& x,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    return x;
  }
  auto stream = detail::communication_stream(group, s);

  return array(
      x.shape(),
      x.dtype(),
      std::make_shared<AllReduce>(stream, group, AllReduce::Max),
      {x});
}

// Element-wise minimum across the group; identity for a singleton group.
array all_min(
    const array& x,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    return x;
  }
  auto stream = detail::communication_stream(group, s);

  return array(
      x.shape(),
      x.dtype(),
      std::make_shared<AllReduce>(stream, group, AllReduce::Min),
      {x});
}

// Concatenate every rank's `x` along axis 0. A scalar input yields a 1-D
// result of length group.size().
array all_gather(
    const array& x,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    return x;
  }
  auto stream = detail::communication_stream(group, s);

  auto result_shape = x.shape();
  if (result_shape.size() == 0) {
    result_shape.push_back(group.size());
  } else {
    result_shape[0] *= group.size();
  }
  return array(
      std::move(result_shape),
      x.dtype(),
      std::make_shared<AllGather>(stream, group),
      {x});
}

// Send `x` to rank `dst`; throws std::invalid_argument for a singleton group
// or an out-of-range destination.
array send(
    const array& x,
    int dst,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    throw std::invalid_argument("Cannot send to a singleton group");
  }
  auto stream = detail::communication_stream(group, s);

  if (dst < 0 || dst >= group.size()) {
    std::ostringstream msg;
    msg << "Invalid destination=" << dst << " for a group of size "
        << group.size();
    throw std::invalid_argument(msg.str());
  }

  return array(
      x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});
}

// Receive an array of the given shape/dtype from rank `src`; throws
// std::invalid_argument for a singleton group or an out-of-range source.
array recv(
    Shape shape,
    Dtype dtype,
    int src,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    throw std::invalid_argument("Cannot recv from a singleton group");
  }
  auto stream = detail::communication_stream(group, s);

  if (src < 0 || src >= group.size()) {
    std::ostringstream msg;
    msg << "Invalid source=" << src << " for a group of size "
        << group.size();
    throw std::invalid_argument(msg.str());
  }

  return array(
      std::move(shape),
      std::move(dtype),
      std::make_shared<Recv>(stream, group, src),
      std::vector<array>{});
}

// recv() using `x`'s shape and dtype.
array recv_like(
    const array& x,
    int src,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  return recv(x.shape(), x.dtype(), src, group_, s);
}

// Sum across the group and keep this rank's slice of axis 0; axis 0 must be
// divisible by the group size.
array sum_scatter(
    const array& x,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);
  if (group.size() == 1) {
    return x;
  }
  // Must precede x.shape()[0]: a 0-d array has no axis 0 to index.
  if (x.ndim() == 0) {
    std::ostringstream msg;
    msg << "[sum_scatter] Cannot scatter a scalar (0-dimensional) array "
        << "across a group of size " << group.size() << ".";
    throw std::invalid_argument(msg.str());
  }
  if (x.shape()[0] % group.size() != 0) {
    std::ostringstream msg;
    msg << "[sum_scatter] Invalid shape=" << x.shape()
        << " for a group of size " << group.size()
        << ". The first dimension (axis 0) must be divisible by the group size.";
    throw std::invalid_argument(msg.str());
  }

  auto result_shape = x.shape();
  result_shape[0] /= group.size();
  auto stream = detail::communication_stream(group, s);

  return array(
      std::move(result_shape),
      x.dtype(),
      std::make_shared<ReduceScatter>(stream, group, ReduceScatter::Sum),
      {x});
}
} // namespace mlx::core::distributed

================================================
FILE: mlx/distributed/ops.h
================================================
// Copyright © 2024 Apple Inc.

#pragma once

#include <optional>

#include "mlx/api.h"
#include "mlx/distributed/distributed.h"
#include "mlx/utils.h"

namespace mlx::core::distributed {

MLX_API array all_sum(
    const array& x,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

MLX_API array all_gather(
    const array& x,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

MLX_API array send(
    const array& x,
    int dst,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

MLX_API array recv(
    Shape shape,
    Dtype dtype,
    int src,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

MLX_API array recv_like(
    const array& x,
    int src,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

MLX_API array all_max(
    const array& x,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

MLX_API array all_min(
    const array& x,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

MLX_API array sum_scatter(
    const array& x,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

} // namespace mlx::core::distributed
================================================ FILE: mlx/distributed/primitives.cpp ================================================ // Copyright © 2024 Apple Inc. #include #include "mlx/allocator.h" #include "mlx/distributed/ops.h" #include "mlx/distributed/primitives.h" #include "mlx/ops.h" namespace mlx::core::distributed { std::pair, std::vector> AllReduce::vmap( const std::vector& inputs, const std::vector& axes) { switch (reduce_type_) { case Sum: return {{all_sum(inputs[0], group(), stream())}, axes}; case Max: return {{all_max(inputs[0], group(), stream())}, axes}; case Min: return {{all_min(inputs[0], group(), stream())}, axes}; default: throw std::runtime_error( "Only all reduce sum, max and min are supported for now"); } } std::vector AllReduce::jvp( const std::vector& primals, const std::vector& tangents, const std::vector&) { switch (reduce_type_) { case Sum: return {all_sum(tangents[0], group(), stream())}; case Max: return {all_max(tangents[0], group(), stream())}; case Min: return {all_min(tangents[0], group(), stream())}; default: throw std::runtime_error( "Only all reduce sum, max and min are supported for now"); } } std::vector AllReduce::vjp( const std::vector& primals, const std::vector& cotangents, const std::vector&, const std::vector& outputs) { return cotangents; } std::pair, std::vector> AllGather::vmap( const std::vector& inputs, const std::vector& axes) { return {{all_gather(inputs[0], group(), stream())}, axes}; } std::vector AllGather::jvp( const std::vector& primals, const std::vector& tangents, const std::vector&) { return {all_gather(tangents[0], group(), stream())}; } std::vector AllGather::vjp( const std::vector& primals, const std::vector& cotangents, const std::vector&, const std::vector&) { auto g = group(); auto ndim = primals[0].ndim(); Shape starts(primals[0].ndim(), 0); auto stops = primals[0].shape(); if (ndim == 0) { starts.push_back(0); stops.push_back(1); } starts[0] = g.rank() * stops[0]; stops[0] += starts[0]; auto 
out = slice(cotangents[0], starts, stops); if (ndim == 0) { out = squeeze(out, 0); } return {out}; } std::pair, std::vector> Send::vmap( const std::vector& inputs, const std::vector& axes) { return {{send(inputs[0], dst_, group(), stream())}, axes}; } } // namespace mlx::core::distributed ================================================ FILE: mlx/distributed/primitives.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/primitives.h" namespace mlx::core::distributed { class DistPrimitive : public Primitive { public: DistPrimitive(Stream stream, Group group) : Primitive(stream), group_(group) {} const Group& group() const { return group_; } private: Group group_; }; class AllReduce : public DistPrimitive { public: enum ReduceType { And, Or, Sum, Prod, Min, Max }; AllReduce(Stream stream, Group group, ReduceType reduce_type) : DistPrimitive(stream, group), reduce_type_(reduce_type) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; std::vector jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) override; std::vector vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; const char* name() const override { switch (reduce_type_) { case And: return "And AllReduce"; case Or: return "Or AllReduce"; case Sum: return "Sum AllReduce"; case Prod: return "Prod AllReduce"; case Min: return "Min AllReduce"; case Max: return "Max AllReduce"; } return ""; } private: ReduceType reduce_type_; }; class AllGather : public DistPrimitive { public: AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {} void eval_cpu(const 
std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; std::vector jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) override; std::vector vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; DEFINE_NAME(AllGather); }; class Send : public DistPrimitive { public: Send(Stream stream, Group group, int dst) : DistPrimitive(stream, group), dst_(dst) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_NAME(Send); private: int dst_; }; class Recv : public DistPrimitive { public: Recv(Stream stream, Group group, int src) : DistPrimitive(stream, group), src_(src) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; DEFINE_NAME(Recv); private: int src_; }; class ReduceScatter : public DistPrimitive { public: enum ReduceType { Sum, Min, Max }; ReduceScatter(Stream stream, Group group, ReduceType reduce_type) : DistPrimitive(stream, group), reduce_type_(reduce_type) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; const char* name() const override { switch (reduce_type_) { case Sum: return "Sum ReduceScatter"; case Min: return "Min ReduceScatter"; case Max: return "Max ReduceScatter"; } return ""; } private: ReduceType reduce_type_; }; } // namespace mlx::core::distributed ================================================ FILE: mlx/distributed/reduction_ops.h ================================================ // 
Copyright © 2025 Apple Inc. namespace mlx::core::distributed::detail { template struct SumOp { void operator()(const T* input, T* output, size_t N) const { while (N-- > 0) { *output += *input; input++; output++; } } }; template struct MaxOp { void operator()(const T* input, T* output, size_t N) const { while (N-- > 0) { *output = std::max(*output, *input); input++; output++; } } }; template struct MinOp { void operator()(const T* input, T* output, size_t N) const { while (N-- > 0) { *output = std::min(*output, *input); input++; output++; } } }; } // namespace mlx::core::distributed::detail ================================================ FILE: mlx/distributed/ring/CMakeLists.txt ================================================ if(MLX_BUILD_CPU AND NOT WIN32) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ring.cpp) endif() ================================================ FILE: mlx/distributed/ring/no_ring.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/distributed/ring/ring.h" namespace mlx::core::distributed::ring { using GroupImpl = mlx::core::distributed::detail::GroupImpl; bool is_available() { return false; } std::shared_ptr init(bool strict /* = false */) { if (strict) { throw std::runtime_error("Cannot initialize ring distributed backend."); } return nullptr; } } // namespace mlx::core::distributed::ring ================================================ FILE: mlx/distributed/ring/ring.cpp ================================================ // Copyright © 2024 Apple Inc. 
#include #include #include #include #include #include #include #include #include #include #include #include #include #include "mlx/backend/cpu/encoder.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/reduction_ops.h" #include "mlx/distributed/utils.h" #include "mlx/threadpool.h" #ifndef SOL_TCP #define SOL_TCP IPPROTO_TCP #endif #define SWITCH_TYPE(x, ...) \ switch ((x).dtype()) { \ case bool_: { \ using T = bool; \ __VA_ARGS__; \ } break; \ case int8: { \ using T = int8_t; \ __VA_ARGS__; \ } break; \ case int16: { \ using T = int16_t; \ __VA_ARGS__; \ } break; \ case int32: { \ using T = int32_t; \ __VA_ARGS__; \ } break; \ case int64: { \ using T = int64_t; \ __VA_ARGS__; \ } break; \ case uint8: { \ using T = uint8_t; \ __VA_ARGS__; \ } break; \ case uint16: { \ using T = uint16_t; \ __VA_ARGS__; \ } break; \ case uint32: { \ using T = uint32_t; \ __VA_ARGS__; \ } break; \ case uint64: { \ using T = uint64_t; \ __VA_ARGS__; \ } break; \ case bfloat16: { \ using T = bfloat16_t; \ __VA_ARGS__; \ } break; \ case float16: { \ using T = float16_t; \ __VA_ARGS__; \ } break; \ case float32: { \ using T = float; \ __VA_ARGS__; \ } break; \ case float64: { \ using T = double; \ __VA_ARGS__; \ } break; \ case complex64: { \ using T = complex64_t; \ __VA_ARGS__; \ } break; \ } namespace mlx::core::distributed::ring { constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024; constexpr const size_t ALL_SUM_BUFFERS = 2; constexpr const int CONN_ATTEMPTS = 5; constexpr const int CONN_WAIT = 1000; constexpr const char* RING_TAG = "[ring]"; using GroupImpl = mlx::core::distributed::detail::GroupImpl; using json = nlohmann::json; using namespace std::chrono_literals; namespace { template void log(std::ostream& os, T first) { os << first << std::endl; } template void log(std::ostream& os, T first, Args... args) { log(os << first << " ", args...); } template void log_info(bool verbose, Args... 
args) { if (!verbose) { return; } log(std::cerr, "[ring]", args...); } template decltype(T() * U()) ceildiv(T a, U b) { return (a + b - 1) / b; } class SocketThread { public: SocketThread(int fd) : fd_(fd), stop_(false) { worker_ = std::thread(&SocketThread::worker, this); int flags = fcntl(fd, F_GETFL, 0); fcntl(fd, F_SETFL, flags | O_NONBLOCK); } ~SocketThread() { stop_ = true; condition_.notify_all(); worker_.join(); int flags = fcntl(fd_, F_GETFL, 0); fcntl(fd_, F_SETFL, flags & ~O_NONBLOCK); } template std::future send(const T* buffer, size_t size) { return send_impl(reinterpret_cast(buffer), size * sizeof(T)); } template std::future recv(T* buffer, size_t size) { return recv_impl(reinterpret_cast(buffer), size * sizeof(T)); } private: struct SocketTask { SocketTask(void* b, size_t s, std::promise&& p) : buffer(b), size(s), promise(std::move(p)) {} SocketTask(SocketTask&& t) : buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {} void* buffer; size_t size; std::promise promise; }; std::future send_impl(const char* buffer, size_t size) { std::promise send_completed_promise; auto send_completed_future = send_completed_promise.get_future(); if (size == 0) { send_completed_promise.set_value(); return send_completed_future; } { std::unique_lock lock(queue_mutex_); sends_.emplace_back(SocketTask( const_cast(buffer), size, std::move(send_completed_promise))); } condition_.notify_one(); return send_completed_future; } std::future recv_impl(char* buffer, size_t size) { std::promise recv_completed_promise; auto recv_completed_future = recv_completed_promise.get_future(); if (size == 0) { recv_completed_promise.set_value(); return recv_completed_future; } { std::unique_lock lock(queue_mutex_); recvs_.emplace_back( SocketTask(buffer, size, std::move(recv_completed_promise))); } condition_.notify_one(); return recv_completed_future; } bool have_tasks() { return !(sends_.empty() && recvs_.empty()); } void worker() { int error_count = 0; bool delete_recv = false; 
bool delete_send = false; while (true) { { std::unique_lock lock(queue_mutex_); if (delete_recv) { recvs_.front().promise.set_value(); recvs_.pop_front(); delete_recv = false; } if (delete_send) { sends_.front().promise.set_value(); sends_.pop_front(); delete_send = false; } if (stop_) { return; } if (!have_tasks()) { condition_.wait(lock, [this] { return stop_ || have_tasks(); }); if (stop_) { return; } } } if (!recvs_.empty()) { auto& task = recvs_.front(); ssize_t r = ::recv(fd_, task.buffer, task.size, 0); if (r > 0) { task.buffer = static_cast(task.buffer) + r; task.size -= r; delete_recv = task.size == 0; error_count = 0; } else if (errno != EAGAIN) { error_count++; log_info( true, "Receiving from socket", fd_, "failed with errno", errno); } } if (!sends_.empty()) { auto& task = sends_.front(); ssize_t r = ::send(fd_, task.buffer, task.size, 0); if (r > 0) { task.buffer = static_cast(task.buffer) + r; task.size -= r; delete_send = task.size == 0; error_count = 0; } else if (errno != EAGAIN) { error_count++; log_info(true, "Sending to socket", fd_, "failed with errno", errno); } } if (error_count >= 10) { log_info(true, "Too many send/recv errors. Aborting..."); return; } } } int fd_; bool stop_; std::thread worker_; std::mutex queue_mutex_; std::condition_variable condition_; std::list sends_; std::list recvs_; }; class CommunicationThreads { public: void add(const std::vector& sockets) { for (int sock : sockets) { threads_.emplace(sock, sock); } } template std::future send(int socket, T* buffer, size_t size) { return threads_.at(socket).send(buffer, size); } template std::future recv(int socket, T* buffer, size_t size) { return threads_.at(socket).recv(buffer, size); } private: std::unordered_map threads_; }; /** * Load all addresses from the json hostfile. The hostfile is a list of * addresses in order of rank. For each rank there can be many addresses so * that we can have multiple connections between peers. 
* * For example: * [ * ["ip1:5000", "ip1:5001"], * ["ip2:5000", "ip2:5001"], * ["ip3:5000", "ip3:5001"], * ] */ std::vector> load_nodes(const char* hostfile) { std::vector> nodes; std::ifstream f(hostfile); json hosts = json::parse(f); for (auto& h : hosts) { std::vector host; for (auto& ips : h) { host.push_back(std::move(detail::parse_address(ips.get()))); } nodes.push_back(std::move(host)); } return nodes; } /** * Create a socket and accept one connection for each of the provided * addresses. */ std::vector accept_connections( const std::vector& addresses) { std::vector sockets; int success; for (auto& address : addresses) { detail::TCPSocket socket(RING_TAG); socket.listen(RING_TAG, address); sockets.push_back(socket.accept(RING_TAG).detach()); } return sockets; } /** * The counterpoint of `accept_connections`. Basically connect to each of the * provided addresses. */ std::vector make_connections( const std::vector& addresses, bool verbose) { std::vector sockets; int success; for (auto& address : addresses) { sockets.push_back( detail::TCPSocket::connect( RING_TAG, address, CONN_ATTEMPTS, CONN_WAIT, [verbose](int attempt, int wait) { log_info( verbose, "Attempt", attempt, "waiting", wait, "ms (error:", errno, ")"); }) .detach()); } return sockets; } } // namespace class RingGroup : public GroupImpl { public: RingGroup( int rank, std::vector> nodes, bool verbose) : rank_(rank), verbose_(verbose), pool_(0) { if (rank_ > 0 && rank_ >= nodes.size()) { throw std::runtime_error( "[ring] Rank cannot be larger than the size of the group"); } size_ = nodes.size(); int connect_to = (rank_ + 1) % size_; // We define the connection order by having the rank_ == size_ - 1 connect // first and accept after. 
if (rank_ < connect_to) { log_info(verbose_, "Rank", rank_, "accepting"); sockets_left_ = accept_connections(nodes[rank_]); log_info(verbose_, "Rank", rank_, "connecting to", connect_to); sockets_right_ = make_connections(nodes[connect_to], verbose); } else { log_info(verbose_, "Rank", rank_, "connecting to", connect_to); sockets_right_ = make_connections(nodes[connect_to], verbose); log_info(verbose_, "Rank", rank_, "accepting"); sockets_left_ = accept_connections(nodes[rank_]); } // Failure if we couldn't make right or left sockets if (sockets_right_.empty()) { std::ostringstream msg; msg << "[ring] Rank " << rank_ << " has no sockets to the right."; throw std::invalid_argument(msg.str()); } if (sockets_left_.empty()) { std::ostringstream msg; msg << "[ring] Rank " << rank_ << " has no sockets to the left."; throw std::invalid_argument(msg.str()); } // The following could be relaxed since we can define non-homogeneous rings // but it makes things a bit simpler for now. if (sockets_right_.size() != sockets_left_.size()) { std::ostringstream msg; msg << "[ring] It is required to have as many connections to the left as " << "to the right but rank " << rank_ << " has " << sockets_right_.size() << " connections to the right and " << sockets_left_.size() << " to the left."; throw std::invalid_argument(msg.str()); } // Configure all sockets to use TCP no delay. int one = 1; for (int i = 0; i < sockets_right_.size(); i++) { setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one)); setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one)); } // Start the all reduce threads. One all reduce per direction per ring. pool_.resize(sockets_right_.size() + sockets_left_.size()); // Create a communication thread per socket. This also converts them to // non-blocking. 
comm_.add(sockets_right_); comm_.add(sockets_left_); // Allocate buffers for the all sum buffers_.resize( (sockets_right_.size() + sockets_left_.size()) * ALL_SUM_BUFFERS * ALL_SUM_SIZE); } ~RingGroup() { for (auto s : sockets_right_) { shutdown(s, 2); close(s); } for (auto s : sockets_left_) { shutdown(s, 2); close(s); } } Stream communication_stream(StreamOrDevice s) override { return to_stream(s, Device::cpu); } int rank() override { return rank_; } int size() override { return size_; } void all_sum(const array& input, array& output, Stream stream) override { SWITCH_TYPE( output, all_reduce(input, output, stream, detail::SumOp())); } void all_max(const array& input, array& output, Stream stream) override { SWITCH_TYPE( output, all_reduce(input, output, stream, detail::MaxOp())); } void all_min(const array& input, array& output, Stream stream) override { SWITCH_TYPE( output, all_reduce(input, output, stream, detail::MinOp())); } std::shared_ptr split(int color, int key = -1) override { throw std::runtime_error("[ring] Group split not supported."); } void all_gather(const array& input, array& output, Stream stream) override { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(input); encoder.set_output_array(output); encoder.dispatch([input_ptr = input.data(), nbytes = input.nbytes(), output_ptr = output.data(), this]() { constexpr size_t min_send_size = 262144; size_t n_gathers = std::max( std::min( sockets_right_.size() + sockets_left_.size(), nbytes / min_send_size), size_t(1)); size_t bytes_per_gather = ceildiv(nbytes, n_gathers); std::vector> all_gathers; for (int i = 0; i < n_gathers; i++) { auto offset = i * bytes_per_gather; all_gathers.emplace_back(pool_.enqueue( std::bind( &RingGroup::all_gather_impl, this, input_ptr + offset, output_ptr + offset, nbytes, offset + bytes_per_gather > nbytes ? nbytes - offset : bytes_per_gather, sockets_right_[i / 2], sockets_left_[i / 2], (i % 2) ? 
-1 : 1))); } for (auto& f : all_gathers) { f.wait(); } }); } void send(const array& input, int dst, Stream stream) override { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(input); encoder.dispatch( [input_ptr = input.data(), nbytes = input.nbytes(), dst, this]() { int right = (rank_ + 1) % size_; int left = (rank_ + size_ - 1) % size_; if (dst == right) { send(sockets_right_, input_ptr, nbytes); } else if (dst == left) { send(sockets_left_, input_ptr, nbytes); } else { std::ostringstream msg; msg << "[ring] Send only supported to direct neighbors " << "but tried to send to " << dst << " from " << rank_ << std::endl; throw std::runtime_error(msg.str()); } }); } void recv(array& out, int src, Stream stream) override { auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(out); encoder.dispatch( [out_ptr = out.data(), nbytes = out.nbytes(), src, this]() { // NOTE: We 'll check the sockets with the opposite order of send so // that they work even with 2 nodes where left and right is the same // neighbor. 
int right = (rank_ + 1) % size_; int left = (rank_ + size_ - 1) % size_; if (src == left) { recv(sockets_left_, out_ptr, nbytes); } else if (src == right) { recv(sockets_right_, out_ptr, nbytes); } else { std::ostringstream msg; msg << "[ring] Recv only supported from direct neighbors " << "but tried to recv from " << src << " to " << rank_ << std::endl; throw std::runtime_error(msg.str()); } }); } void sum_scatter(const array& input, array& output, Stream stream) override { throw std::runtime_error("[ring] sum_scatter not supported."); } private: template void all_reduce( const array& input, array& output, Stream stream, ReduceOp reduce_op) { auto in_ptr = input.data(); auto out_ptr = output.data(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(output); encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() { // If the input data cannot be split into size_ segments then copy it and // all reduce a local buffer prefilled with 0s. size_t nbytes = size * sizeof(T); if (size < size_) { // TODO: Maybe allocate dynamically so we don't have the constraint // below? if (sizeof(T) * size_ > 1024) { std::ostringstream msg; msg << "Can't perform the ring all reduce of " << size << " elements with a ring of size " << size_; throw std::runtime_error(msg.str()); } char buffer[1024]; std::memset(buffer, 0, size_ * sizeof(T)); std::memcpy(buffer, in_ptr, nbytes); all_reduce_impl( reinterpret_cast(buffers_.data()), reinterpret_cast(buffer), size_, sockets_right_[0], sockets_left_[0], -1, reduce_op); std::memcpy(out_ptr, buffer, nbytes); return; } // If not inplace all reduce then copy the input to the output first if (in_ptr != out_ptr) { std::memcpy(out_ptr, in_ptr, nbytes); } // Split the all reduces so that each member has at least 1 buffer to // send/recv per segment. 
constexpr size_t min_send_size = 262144; size_t n_reduces = std::max( std::min( sockets_right_.size() + sockets_left_.size(), nbytes / (size_ * min_send_size)), size_t(1)); size_t step = ceildiv(size, n_reduces); std::vector> all_sums; for (int i = 0; i < n_reduces; i++) { all_sums.emplace_back(pool_.enqueue( std::bind( &RingGroup::all_reduce_impl, this, reinterpret_cast( buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS), reinterpret_cast(out_ptr) + i * step, std::min(size, (i + 1) * step) - i * step, sockets_right_[i / 2], sockets_left_[i / 2], (i % 2) ? -1 : 1, reduce_op))); } for (auto& f : all_sums) { f.wait(); } }); } template void all_reduce_impl( T* buffer, T* data, size_t data_size, int socket_right, int socket_left, int direction, ReduceOp reduce_op) { // Choose which socket we send to and recv from int socket_send = (direction < 0) ? socket_right : socket_left; int socket_recv = (direction < 0) ? socket_left : socket_right; // We split the data into `size_` segments of size `segment_size` and each // of these in smaller segments of ALL_SUM_SIZE which we 'll call packets. size_t segment_size = ceildiv(data_size, size_); size_t BUFFER_SIZE = std::max( size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2)); size_t n_packets = ceildiv(segment_size, BUFFER_SIZE); // Initial segments int send_segment = rank_; int recv_segment = (rank_ + direction + size_) % size_; // Plan the whole reduce in terms of sends and recvs as indices in data. // It makes the actual async send and recv a bit simpler to follow when // there are less offset calculations around. std::vector> send_plan; std::vector> recv_plan; // Two times the same send/recv operations, first scatter reduce and then // gather. 
for (int k = 0; k < 2; k++) { for (int i = 0; i < size_ - 1; i++) { size_t send_start = send_segment * segment_size; size_t send_stop = std::min((send_segment + 1) * segment_size, data_size); size_t recv_start = recv_segment * segment_size; size_t recv_stop = std::min((recv_segment + 1) * segment_size, data_size); for (size_t j = 0; j < n_packets; j++) { send_plan.emplace_back( std::min(send_start + j * BUFFER_SIZE, send_stop), std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop)); recv_plan.emplace_back( std::min(recv_start + j * BUFFER_SIZE, recv_stop), std::min(recv_start + (j + 1) * BUFFER_SIZE, recv_stop)); } send_segment = (send_segment + size_ + direction) % size_; recv_segment = (recv_segment + size_ + direction) % size_; } } // Running the plan is fairly simple, we keep a send and a recv in flight // while doing the summation. T* recv_buffers[ALL_SUM_BUFFERS]; for (int i = 0; i < ALL_SUM_BUFFERS; i++) { recv_buffers[i] = buffer + i * BUFFER_SIZE; } std::future sends[2], recvs[2]; int a = 0; int b = (n_packets > 1) ? 1 : 0; for (int i = 0, j = -b; i < send_plan.size(); j++, i++) { sends[a] = comm_.send( socket_send, data + send_plan[i].first, send_plan[i].second - send_plan[i].first); if (2 * i < send_plan.size()) { recvs[a] = comm_.recv( socket_recv, recv_buffers[i % ALL_SUM_BUFFERS], recv_plan[i].second - recv_plan[i].first); } else { recvs[a] = comm_.recv( socket_recv, data + recv_plan[i].first, recv_plan[i].second - recv_plan[i].first); } if (j >= 0) { sends[b].wait(); recvs[b].wait(); if (2 * j < send_plan.size()) { reduce_op( recv_buffers[j % ALL_SUM_BUFFERS], data + recv_plan[j].first, recv_plan[j].second - recv_plan[j].first); } } std::swap(a, b); } sends[b].wait(); recvs[b].wait(); } void all_gather_impl( const char* input, char* output, size_t input_size, size_t data_size, int socket_right, int socket_left, int direction) { // Choose which socket we send to and recv from int socket_send = (direction < 0) ? 
socket_right : socket_left; int socket_recv = (direction < 0) ? socket_left : socket_right; // Initial segments int send_segment = rank_; int recv_segment = (rank_ + direction + size_) % size_; // Copy our own segment in the output std::memcpy(output + rank_ * input_size, input, data_size); // Simple send/recv all gather. Possible performance improvement by // splitting to multiple chunks and allowing send/recv to run a bit ahead. // See all_sum_impl for an example. for (int i = 0; i < size_ - 1; i++) { auto sent = comm_.send( socket_send, output + send_segment * input_size, data_size); auto recvd = comm_.recv( socket_recv, output + recv_segment * input_size, data_size); send_segment = (send_segment + size_ + direction) % size_; recv_segment = (recv_segment + size_ + direction) % size_; sent.wait(); recvd.wait(); } } void send(const std::vector& sockets, const char* data, size_t data_size) { size_t segment_size = std::max(size_t(1024), ceildiv(data_size, sockets.size())); std::vector> sends; for (int i = 0; i < sockets.size(); i++) { if (i * segment_size >= data_size) { break; } sends.emplace_back(comm_.send( sockets[i], data + i * segment_size, std::min(data_size, (i + 1) * segment_size) - i * segment_size)); } for (auto& f : sends) { f.wait(); } } void recv(const std::vector& sockets, char* data, size_t data_size) { size_t segment_size = std::max(size_t(1024), ceildiv(data_size, sockets.size())); std::vector> recvs; for (int i = 0; i < sockets.size(); i++) { if (i * segment_size >= data_size) { break; } recvs.emplace_back(comm_.recv( sockets[i], data + i * segment_size, std::min(data_size, (i + 1) * segment_size) - i * segment_size)); } for (auto& f : recvs) { f.wait(); } } int rank_; int size_; bool verbose_; ThreadPool pool_; CommunicationThreads comm_; std::vector sockets_right_; std::vector sockets_left_; std::vector buffers_; }; bool is_available() { return true; } std::shared_ptr init(bool strict /* = false */) { const char* hostfile = 
std::getenv("MLX_HOSTFILE"); const char* rank_str = std::getenv("MLX_RANK"); const char* ring_verbose = std::getenv("MLX_RING_VERBOSE"); if (!hostfile || !rank_str) { if (strict) { std::ostringstream msg; msg << "[ring] You need to provide via environment variables both a rank (MLX_RANK) " << "and a hostfile (MLX_HOSTFILE) but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "") << "\" and MLX_HOSTFILE=\"" << ((hostfile) ? hostfile : "") << "\""; throw std::runtime_error(msg.str()); } return nullptr; } auto nodes = load_nodes(hostfile); int rank = std::atoi(rank_str); return std::make_shared(rank, nodes, ring_verbose != nullptr); } } // namespace mlx::core::distributed::ring ================================================ FILE: mlx/distributed/ring/ring.h ================================================ // Copyright © 2024 Apple Inc. #include "mlx/distributed/distributed.h" namespace mlx::core::distributed::ring { using GroupImpl = mlx::core::distributed::detail::GroupImpl; bool is_available(); std::shared_ptr init(bool strict = false); } // namespace mlx::core::distributed::ring ================================================ FILE: mlx/distributed/utils.cpp ================================================ // Copyright © 2025 Apple Inc. #include #include #include #include #include #include "mlx/distributed/utils.h" namespace mlx::core::distributed::detail { /** * Parse a sockaddr from an ip and port provided as strings. 
*/ address_t parse_address(const std::string& ip, const std::string& port) { struct addrinfo hints, *res; std::memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_UNSPEC; hints.ai_socktype = SOCK_STREAM; int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res); if (status != 0) { std::ostringstream msg; msg << "Can't parse address " << ip << ":" << port; throw std::runtime_error(msg.str()); } address_t result; memcpy(&result.addr, res->ai_addr, res->ai_addrlen); result.len = res->ai_addrlen; freeaddrinfo(res); return result; } /** * Parse a sockaddr provided as an : string. */ address_t parse_address(const std::string& ip_port) { auto colon = ip_port.find(":"); if (colon == std::string::npos) { std::ostringstream msg; msg << "Can't parse address " << ip_port; throw std::runtime_error(msg.str()); } std::string ip(ip_port.begin(), ip_port.begin() + colon); std::string port(ip_port.begin() + colon + 1, ip_port.end()); return parse_address(ip, port); } TCPSocket::TCPSocket(const char* tag) { sock_ = socket(AF_INET, SOCK_STREAM, 0); if (sock_ < 0) { std::ostringstream msg; msg << tag << " Couldn't create socket (error: " << errno << ")"; throw std::runtime_error(msg.str()); } } TCPSocket::TCPSocket(TCPSocket&& s) { sock_ = s.sock_; s.sock_ = -1; } TCPSocket& TCPSocket::operator=(TCPSocket&& s) { if (this != &s) { sock_ = s.sock_; s.sock_ = -1; } return *this; } TCPSocket::TCPSocket(int s) : sock_(s) {} TCPSocket::~TCPSocket() { if (sock_ > 0) { shutdown(sock_, 2); close(sock_); } } int TCPSocket::detach() { int s = sock_; sock_ = -1; return s; } void TCPSocket::listen(const char* tag, const address_t& addr) { int success; // Make sure we can launch immediately after shutdown by setting the // reuseaddr option so that we don't get address already in use errors int enable = 1; success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); if (success < 0) { std::ostringstream msg; msg << tag << " Couldn't enable reuseaddr (error: " << errno << 
")"; throw std::runtime_error(msg.str()); } success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int)); if (success < 0) { std::ostringstream msg; msg << tag << " Couldn't enable reuseport (error: " << errno << ")"; throw std::runtime_error(msg.str()); } // Bind the socket to the address and port success = bind(sock_, addr.get(), addr.len); if (success < 0) { std::ostringstream msg; msg << tag << " Couldn't bind socket (error: " << errno << ")"; throw std::runtime_error(msg.str()); } // Prepare waiting for connections success = ::listen(sock_, 0); if (success < 0) { std::ostringstream msg; msg << tag << " Couldn't listen (error: " << errno << ")"; throw std::runtime_error(msg.str()); } } TCPSocket TCPSocket::accept(const char* tag) { int peer = ::accept(sock_, nullptr, nullptr); if (peer < 0) { std::ostringstream msg; msg << tag << " Accept failed (error: " << errno << ")"; throw std::runtime_error(msg.str()); } return TCPSocket(peer); } void TCPSocket::send(const char* tag, const void* data, size_t len) { while (len > 0) { auto n = ::send(sock_, data, len, 0); if (n <= 0) { std::ostringstream msg; msg << tag << " Send failed with errno=" << errno; throw std::runtime_error(msg.str()); } len -= n; data = static_cast(data) + n; } } void TCPSocket::recv(const char* tag, void* data, size_t len) { while (len > 0) { auto n = ::recv(sock_, data, len, 0); if (n <= 0) { std::ostringstream msg; msg << tag << " Recv failed with errno=" << errno; throw std::runtime_error(msg.str()); } len -= n; data = static_cast(data) + n; } } TCPSocket TCPSocket::connect( const char* tag, const address_t& addr, int num_retries, int wait, std::function cb) { int sock, success; // Attempt to connect `num_retries` times with exponential backoff. 
for (int attempt = 0; attempt < num_retries; attempt++) { // Create the socket sock = socket(AF_INET, SOCK_STREAM, 0); if (sock < 0) { std::ostringstream msg; msg << tag << " Couldn't create socket to connect (error: " << errno << ")"; throw std::runtime_error(msg.str()); } success = ::connect(sock, addr.get(), addr.len); if (success == 0) { break; } if (cb != nullptr) { cb(attempt, wait); } if (wait > 0) { std::this_thread::sleep_for(std::chrono::milliseconds(wait)); } wait <<= 1; } if (success < 0) { std::ostringstream msg; msg << tag << " Couldn't connect (error: " << errno << ")"; throw std::runtime_error(msg.str()); } return TCPSocket(sock); } } // namespace mlx::core::distributed::detail ================================================ FILE: mlx/distributed/utils.h ================================================ // Copyright © 2025 Apple Inc. #pragma once #include #include #include namespace mlx::core::distributed::detail { struct address_t { sockaddr_storage addr; socklen_t len; const sockaddr* get() const { return (struct sockaddr*)&addr; } }; /** * Parse a sockaddr from an ip and port provided as strings. */ address_t parse_address(const std::string& ip, const std::string& port); /** * Parse a sockaddr provided as an : string. */ address_t parse_address(const std::string& ip_port); /** * Small wrapper over a TCP socket to simplify initiating connections. 
*/
// Owns a single socket descriptor: copying is disabled, moving transfers the
// descriptor, and the destructor closes it (see utils.cpp).
class TCPSocket {
 public:
  // Creates a new socket; `tag` prefixes any error message.
  TCPSocket(const char* tag);
  TCPSocket(const TCPSocket&) = delete;
  TCPSocket& operator=(const TCPSocket&) = delete;
  TCPSocket(TCPSocket&& s);
  TCPSocket& operator=(TCPSocket&&);
  ~TCPSocket();

  // Bind to `addr` and start listening for connections.
  void listen(const char* tag, const address_t& addr);
  // Accept one incoming connection and return it as a new TCPSocket.
  TCPSocket accept(const char* tag);
  // Send / receive exactly `len` bytes, throwing on failure.
  void send(const char* tag, const void* data, size_t len);
  void recv(const char* tag, void* data, size_t len);
  // Release ownership of the descriptor without closing it.
  int detach();
  // Expose the raw descriptor for direct use with the socket APIs.
  operator int() const {
    return sock_;
  }

  // Connect to `addr`, attempting up to `num_retries` times and doubling
  // `wait` (milliseconds) between attempts; `cb`, if set, is invoked after a
  // failed attempt. NOTE(review): the std::function template argument was
  // lost in this dump; presumably std::function<void(int, int)>, matching
  // the cb(attempt, wait) call in utils.cpp — confirm.
  static TCPSocket connect(
      const char* tag,
      const address_t& addr,
      int num_retries = 1,
      int wait = 0,
      std::function cb = nullptr);

 private:
  // Wrap an already-open descriptor.
  TCPSocket(int sock);

  int sock_;
};

} // namespace mlx::core::distributed::detail
================================================
FILE: mlx/dtype.cpp
================================================
// Copyright © 2023-2024 Apple Inc.

#include
#include "mlx/dtype.h"

namespace mlx::core {

namespace {

// Sizes of the Dtype::Val and Dtype::Category enums. The tables below are
// indexed by those enums' underlying values, so their rows must follow the
// enum declaration order.
constexpr int num_types = 14;
constexpr int num_cats = 8;

// Kind of each dtype, in Dtype::Val order.
constexpr Dtype::Kind type_kinds[num_types] = {
    Dtype::Kind::b, // bool_,
    Dtype::Kind::u, // uint8,
    Dtype::Kind::u, // uint16,
    Dtype::Kind::u, // uint32,
    Dtype::Kind::u, // uint64,
    Dtype::Kind::i, // int8,
    Dtype::Kind::i, // int16,
    Dtype::Kind::i, // int32,
    Dtype::Kind::i, // int64,
    Dtype::Kind::f, // float16,
    Dtype::Kind::f, // float32,
    Dtype::Kind::f, // float64,
    Dtype::Kind::V, // bfloat16,
    Dtype::Kind::c // complex64,
};

// Following Jax type promotion rules:
// https://jax.readthedocs.io/en/latest/type_promotion.html
// type_rules[a][b] is the dtype that results from combining dtype a with
// dtype b; both rows and columns are in Dtype::Val order.
// clang-format off
constexpr Dtype type_rules[num_types][num_types] = {
  // bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 float64 bfloat16 complex64
  {bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // bool
  {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint8
  {uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint16
  {uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // uint32
  {uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, float64, bfloat16, complex64}, // uint64
  {int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int8
  {int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int16
  {int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // int32
  {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // int64
  {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float64, float32, complex64}, // float16
  {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float32
  {float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, complex64}, // float64
  {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, float64, bfloat16, complex64}, // bfloat16
  {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64,complex64, complex64, complex64}, // complex64
};

// subcategory_to_category[a][b] is true when category a counts as category b
// (read by issubdtype(Category, Category)); in Dtype::Category order.
constexpr bool subcategory_to_category[num_cats][num_cats] = {
  // complexfloating floating inexact signedinteger unsignedinteger integer number generic
  {true, false, true, false, false, false, true, true}, // complexfloating
  {false, true, true, false, false, false, true, true}, // floating
  {false, false, true, false, false, false, true, true}, // inexact
  {false, false, false, true, false, true, true, true}, // signedinteger
  {false, false, false, false, true, true, true, true}, // unsignedinteger
  {false, false, false, false, false, true, true, true}, // integer
  {false, false, false, false, false, false, true, true}, // number
  {false, false, false, false, false, false, false, true}, // generic
};

// Most specific category of each dtype, in Dtype::Val order. Note that
// bool_ is only "generic".
constexpr Dtype::Category type_to_category[num_types] = {
  Dtype::Category::generic, // bool_,
  Dtype::Category::unsignedinteger, // uint8,
  Dtype::Category::unsignedinteger, // uint16,
  Dtype::Category::unsignedinteger, // uint32,
  Dtype::Category::unsignedinteger, // uint64,
  Dtype::Category::signedinteger, // int8,
  Dtype::Category::signedinteger, // int16,
  Dtype::Category::signedinteger, // int32,
  Dtype::Category::signedinteger, // int64,
  Dtype::Category::floating, // float16,
  Dtype::Category::floating, // float32,
  Dtype::Category::floating, // float64,
  Dtype::Category::floating, // bfloat16,
  Dtype::Category::complexfloating, // complex64,
};

// clang-format on

} // namespace

// Dtype produced by combining t1 with t2 (a type_rules lookup).
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
  return Dtype(
      type_rules[static_cast(t1.val())][static_cast(t2.val())]);
}

// Kind of `t` (a type_kinds lookup).
Dtype::Kind kindof(const Dtype& t) {
  return type_kinds[static_cast(t.val())];
}

// Explicit instantiations of TypeToDtype, exported via MLX_API.
// NOTE(review): the template arguments were lost in this dump.
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;
template class MLX_API TypeToDtype;

// Conversion operators mapping each C++ type to its Dtype.
// NOTE(review): the template arguments were lost in this dump. Two
// specializations return float32; the second presumably corresponds to
// `double` — confirm.
template <>
TypeToDtype::operator Dtype() {
  return bool_;
}

template <>
TypeToDtype::operator Dtype() {
  return uint8;
}

template <>
TypeToDtype::operator Dtype() {
  return uint16;
}

template <>
TypeToDtype::operator Dtype() {
  return uint32;
}

template <>
TypeToDtype::operator Dtype() {
  return uint64;
}

template <>
TypeToDtype::operator Dtype() {
  return int8;
}

template <>
TypeToDtype::operator Dtype() {
  return int16;
}

template <>
TypeToDtype::operator Dtype() {
  return int32;
}

template <>
TypeToDtype::operator Dtype() {
  return int64;
}

template <>
TypeToDtype::operator Dtype() {
  return float16;
}

template <>
TypeToDtype::operator Dtype() {
  return float32;
}

template <>
TypeToDtype::operator Dtype() {
  return float32;
}

template <>
TypeToDtype::operator Dtype() {
  return bfloat16;
}

template <>
TypeToDtype::operator Dtype() {
  return complex64;
}

// Two concrete dtypes: a "sub-dtype" only of itself.
bool issubdtype(const Dtype& a, const Dtype& b) {
  return a == b;
}

// A category is never a sub-dtype of a concrete dtype.
bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
  return false;
}

// Map the dtype to its own category, then compare categories.
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
  return issubdtype(type_to_category[static_cast(type.val())], cat);
}

// Category vs category: a subcategory_to_category lookup.
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
  return subcategory_to_category[static_cast(a)]
                                [static_cast(b)];
}

} // namespace mlx::core
================================================
FILE: mlx/dtype.h
================================================
// Copyright © 2023-2024 Apple Inc.
#pragma once

#include
#include

#include "mlx/api.h"
#include "mlx/types/complex.h"
#include "mlx/types/half_types.h"

namespace mlx::core {

// A data type: which element type it is (`Val`) plus the size in bytes of
// one element.
struct Dtype {
  // Element types. The declaration order matters: the lookup tables in
  // dtype.cpp are indexed by these values.
  enum class Val {
    bool_,
    uint8,
    uint16,
    uint32,
    uint64,
    int8,
    int16,
    int32,
    int64,
    float16,
    float32,
    float64,
    bfloat16,
    complex64,
  };

  // Coarse kind of a dtype, as returned by kindof().
  enum class Kind {
    b, /* bool */
    u, /* unsigned int */
    i, /* signed int */
    f, /* float */
    c, /* complex */
    V, /* void - used for brain float */
  };

  // Abstract categories used by issubdtype(); the declaration order matches
  // the category tables in dtype.cpp.
  enum class Category {
    complexfloating,
    floating,
    inexact,
    signedinteger,
    unsignedinteger,
    integer,
    number,
    generic
  };

  constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {}

  // Implicit conversion to Val, which lets a Dtype be used directly in a
  // `switch` with the dtype constants as case labels.
  constexpr operator Val() const {
    return val_;
  }
  constexpr Val val() const {
    return val_;
  }
  // Size of one element in bytes.
  constexpr uint8_t size() const {
    return size_;
  }

 private:
  Val val_;
  uint8_t size_;
};

// Predefined dtypes, each paired with the byte size of its element type.
// float16 and bfloat16 are both sized as 2-byte (uint16_t) values.
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};

inline constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
inline constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
inline constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
inline constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};

inline constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
inline constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
inline constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};

inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)};
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};

// Shorthand names for the categories, for use with issubdtype().
inline constexpr Dtype::Category complexfloating =
    Dtype::Category::complexfloating;
inline constexpr Dtype::Category floating = Dtype::Category::floating;
inline constexpr Dtype::Category inexact = Dtype::Category::inexact;
inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;
inline constexpr Dtype::Category unsignedinteger =
    Dtype::Category::unsignedinteger;
inline constexpr Dtype::Category integer = Dtype::Category::integer;
inline constexpr Dtype::Category number = Dtype::Category::number;
inline constexpr Dtype::Category generic = Dtype::Category::generic;

// Subtype tests between dtypes and/or categories (defined in dtype.cpp).
MLX_API bool issubdtype(const Dtype& a, const Dtype& b);
MLX_API bool issubdtype(const Dtype::Category& a, const Dtype& b);
MLX_API bool issubdtype(const Dtype& a, const Dtype::Category& b);
MLX_API bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);

// Dtype resulting from combining t1 with t2.
MLX_API Dtype promote_types(const Dtype& t1, const Dtype& t2);

// Size in bytes of one element of dtype `t`.
inline uint8_t size_of(const Dtype& t) {
  return t.size();
}

MLX_API Dtype::Kind kindof(const Dtype& t);

// Maps a C++ type to its Dtype through a conversion operator; the
// specializations live in dtype.cpp. NOTE(review): the template parameter
// list was lost in this dump.
template struct MLX_API TypeToDtype {
  operator Dtype();
};

} // namespace mlx::core
================================================
FILE: mlx/dtype_utils.cpp
================================================
// Copyright © 2025 Apple Inc.

#include "mlx/dtype_utils.h"

namespace mlx::core {

// Returns a string literal naming `arg`, or "unknown" for any value not
// listed below.
const char* dtype_to_string(Dtype arg) {
  switch (arg) {
    case bool_:
      return "bool";
    case int8:
      return "int8";
    case int16:
      return "int16";
    case int32:
      return "int32";
    case int64:
      return "int64";
    case uint8:
      return "uint8";
    case uint16:
      return "uint16";
    case uint32:
      return "uint32";
    case uint64:
      return "uint64";
    case float16:
      return "float16";
    case bfloat16:
      return "bfloat16";
    case float32:
      return "float32";
    case float64:
      return "float64";
    case complex64:
      return "complex64";
    default:
      return "unknown";
  }
}

} // namespace mlx::core
================================================
FILE: mlx/dtype_utils.h
================================================
// Copyright © 2025 Apple Inc.

#pragma once

#include

#include "mlx/dtype.h"
#include "mlx/utils.h"

namespace mlx::core {

// Return the string representation of a dtype.
const char* dtype_to_string(Dtype arg);

// Expands to one `case` of a switch over a Dtype. The case calls the visitor
// `f` (a name that must be in scope at the expansion site) with a
// type_identity tag for the C++ type TYPE.
// NOTE(review): the tag's template argument was lost in this dump;
// presumably type_identity<TYPE>{} — confirm.
#define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \
  case DTYPE:                                       \
    f(type_identity{});                             \
    break

// Cases for all signed and unsigned integer dtypes.
#define MLX_INTERNAL_DTYPE_SWITCH_INTS()         \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t);   \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t); \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t); \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t); \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t); \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t)

// Cases for all floating point dtypes.
#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS()              \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t);   \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float);       \
  MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double)

// std::type_identity already exists in C++20, but with C++20 we could also
// just use templated lambdas, which would make all of this much nicer.
template struct type_identity {
  using type = T;
};

// Recover, inside a visitor, the C++ type carried by tag `x`, or the static
// `value` member of the tag's type.
#define MLX_GET_TYPE(x) typename decltype(x)::type
#define MLX_GET_VALUE(x) decltype(x)::value

// Call `f` with the type tag matching `dt`. Every Dtype::Val is listed, so
// there is no default case.
template void dispatch_all_types(Dtype dt, F&& f) {
  switch (dt) {
    MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
    MLX_INTERNAL_DTYPE_SWITCH_INTS();
    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
    MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
  }
}

// Call `f` for integer dtypes only. Any other dtype throws
// std::invalid_argument with a message prefixed by `tag`.
template void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) {
  switch (dt) {
    MLX_INTERNAL_DTYPE_SWITCH_INTS();
    default:
      std::ostringstream msg;
      msg << tag << " Only integer types supported but " << dt
          << " was provided";
      throw std::invalid_argument(msg.str());
  }
}

// Call `f` for floating point dtypes only; otherwise throw
// std::invalid_argument.
template void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) {
  switch (dt) {
    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
    default:
      std::ostringstream msg;
      msg << tag << " Only float types supported but " << dt
          << " was provided";
      throw std::invalid_argument(msg.str());
  }
}

// Call `f` for floating point and complex64 dtypes; otherwise throw
// std::invalid_argument.
template void dispatch_inexact_types(Dtype dt, std::string_view tag, F&& f) {
  switch (dt) {
    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
    MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
    default:
      std::ostringstream msg;
      msg << tag << " Only inexact (float/complex) types supported but " << dt
          << " was provided";
      throw std::invalid_argument(msg.str());
  }
}

// Call `f` for integer and floating point dtypes (not bool_ or complex64);
// otherwise throw std::invalid_argument.
template void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) {
  switch (dt) {
    MLX_INTERNAL_DTYPE_SWITCH_INTS();
    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
    default:
      std::ostringstream msg;
      msg << tag << " Only integer and float types supported but " << dt
          << " was provided";
      throw std::invalid_argument(msg.str());
  }
}

// Call `f` for every dtype except complex64 (bool_, integers and floats);
// otherwise throw std::invalid_argument.
template void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) {
  switch (dt) {
    MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
    MLX_INTERNAL_DTYPE_SWITCH_INTS();
    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
    default:
      std::ostringstream msg;
      msg << tag << " Only real numbers supported but " << dt
          << " was provided";
      throw std::invalid_argument(msg.str());
  }
}

} // namespace mlx::core
================================================
FILE: mlx/einsum.cpp
================================================
// Copyright © 2024 Apple Inc.
#include #include #include #include #include "mlx/einsum.h" #include "mlx/ops.h" namespace mlx::core { namespace { // The MLX einsum implementation is based on NumPy (which is based on // opt_einsum): // https://github.com/numpy/numpy/blob/1d49c7f7ff527c696fc26ab2278ad51632a66660/numpy/_core/einsumfunc.py#L743 // https://github.com/dgasmith/opt_einsum using CharSet = std::unordered_set; // A helper struct to hold the string and set // representation of a subscript to avoid needing // to recompute the set struct Subscript { Subscript(std::string str, CharSet set) : str(std::move(str)), set(std::move(set)) {}; std::string str; CharSet set; }; struct PathInfo { size_t naive_cost; size_t naive_scaling; size_t optimized_cost; size_t optimized_scaling; size_t largest_term; }; struct PathNode { PathNode( std::vector inputs, Subscript output, std::vector positions) : inputs(std::move(inputs)), output(std::move(output)), positions(std::move(positions)) {}; std::vector inputs; Subscript output; std::vector positions; }; // Parse the comma separated subscripts into a vector of strings. If the // output subscripts are missing they are inferred. // // For example: // "ij,jk -> ik" becomes {{"ij", "jk"}, "ik"} // "ij,jk" becomes {{"ij", "jk"}, "ik"} std::pair, std::string> parse(std::string subscripts) { std::string lhs, rhs; // Start by removing all white space subscripts.erase( std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end()); if (auto pos = subscripts.find("->"); pos != std::string::npos) { // Explicit mode lhs = subscripts.substr(0, pos); rhs = subscripts.substr(pos + 2); } else { // Implicit mode: // - repeats are summed // - ellipses are placed in the beginning of the output // - remaining output axes are ordered alphabetically lhs = subscripts; std::unordered_map temp; for (auto& c : subscripts) { if (c == ',') { continue; } if (c == '.' 
&& rhs.empty()) { rhs += "..."; continue; } auto inserted = temp.insert({c, 0}); inserted.first->second++; } for (auto& k : temp) { if (k.second == 1) { rhs += k.first; } } std::sort(rhs.begin(), rhs.end()); } std::vector input_list; std::stringstream ss(lhs); std::string token; while (getline(ss, token, ',')) { input_list.push_back(token); } return {input_list, rhs}; } // Check if two sets are disjoint bool disjoint(const CharSet& x, const CharSet& y) { for (auto& c : x) { if (y.find(c) != y.end()) { return false; } } return true; } template size_t term_size(const T& term, std::unordered_map dict) { size_t size = 1; for (auto c : term) { size *= dict[c]; } return size; } size_t flop_count( const CharSet& term, bool inner, int num_terms, std::unordered_map dict) { size_t size = term_size(term, dict); auto op_factor = 1; if ((num_terms - 1) > op_factor) { op_factor = num_terms - 1; } if (inner) { op_factor += 1; } return size * op_factor; } std::pair compute_cost_and_scaling( const std::vector& inputs, const Subscript& output, std::unordered_map dim_map) { CharSet contractions; for (auto& in : inputs) { contractions.insert(in.set.begin(), in.set.end()); } bool inner = false; for (auto c : contractions) { if (output.set.find(c) == output.set.end()) { inner = true; break; } } auto cost = flop_count(contractions, inner, inputs.size(), dim_map); return {cost, contractions.size()}; } std::tuple, size_t, int> greedy_path( std::vector inputs, const Subscript& output, std::unordered_map dim_map, size_t cost_limit, size_t memory_limit) { // Helper struct for building the greedy path struct Contraction { Contraction( size_t size, size_t cost, CharSet output, int dims, int x, int y) : size(size), cost(cost), output(std::move(output)), dims(dims), x(x), y(y) {}; int64_t size; // Size difference, can be negative size_t cost; CharSet output; int dims; // Number of dimensions in the contraction int x; int y; }; // Start by iterating over all possible combinations std::vector> 
pos_pairs; for (int i = 0; i < inputs.size(); ++i) { for (int j = i + 1; j < inputs.size(); ++j) { pos_pairs.emplace_back(i, j); } } std::vector path; std::vector possible_contractions; size_t path_cost = 0; int path_scaling = 0; auto num_in = inputs.size(); for (int i = 0; i < num_in - 1; ++i) { auto add_contraction = [&](int p1, int p2) { CharSet new_term; CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end()); contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end()); for (int i = 0; i < inputs.size(); i++) { if (i == p1 || i == p2) { continue; } auto& in = inputs[i].set; for (auto c : in) { if (contractions.find(c) != contractions.end()) { new_term.insert(c); } } } for (auto c : output.set) { if (contractions.find(c) != contractions.end()) { new_term.insert(c); } } // Ignore if: // - The size of the new result is greater than the memory limit // - The cost is larger than the naive cost auto new_size = term_size(new_term, dim_map); if (new_size > memory_limit) { return; } int64_t removed_size = term_size(inputs[p1].set, dim_map) + term_size(inputs[p2].set, dim_map) - new_size; bool inner = contractions.size() > new_term.size(); auto cost = flop_count(contractions, inner, 2, dim_map); if (path_cost + cost > cost_limit) { return; } possible_contractions.emplace_back( removed_size, cost, std::move(new_term), contractions.size(), p1, p2); }; for (auto& [p1, p2] : pos_pairs) { // Ignore outer products if (!disjoint(inputs[p1].set, inputs[p2].set)) { add_contraction(p1, p2); } } // If there's nothing in the contraction list, // go over the pairs again without ignoring outer products if (possible_contractions.empty()) { for (auto& [p1, p2] : pos_pairs) { add_contraction(p1, p2); } } if (possible_contractions.empty()) { // Default to naive einsum for the remaining inputs std::vector positions(inputs.size()); std::iota(positions.begin(), positions.end(), 0); auto [cost, scale] = compute_cost_and_scaling(inputs, output, dim_map); 
path.emplace_back(std::move(inputs), output, std::move(positions)); path_cost += cost; path_scaling = std::max(scale, path_scaling); break; } // Find the best contraction auto& best = *std::min_element( possible_contractions.begin(), possible_contractions.end(), [](const auto& x, const auto& y) { return x.size > y.size || (x.size == y.size && x.cost < y.cost); }); path_scaling = std::max(best.dims, path_scaling); // Construct the output subscripts std::string out_str(best.output.begin(), best.output.end()); // TODO, sorting by dimension size seems suboptimal? std::sort(out_str.begin(), out_str.end(), [&dim_map](auto x, auto y) { return dim_map[x] < dim_map[y]; }); Subscript new_output(std::move(out_str), std::move(best.output)); // Add the chosen contraction to the path { std::vector in_terms; in_terms.push_back(std::move(inputs[best.x])); in_terms.push_back(std::move(inputs[best.y])); path.emplace_back( std::move(in_terms), new_output, std::vector{best.x, best.y}); } // Remove used terms inputs.erase(inputs.begin() + best.y); inputs.erase(inputs.begin() + best.x); // Add the new result inputs.push_back(std::move(new_output)); // Update the existing contractions based on the selected one std::vector updated_contractions; for (auto& contraction : possible_contractions) { // Drop contractions which contain either selected term if (contraction.x == best.x || contraction.x == best.y || contraction.y == best.x || contraction.y == best.y) { continue; } // Update the positions of other contractions int x = contraction.x - (contraction.x > best.x) - (contraction.x > best.y); int y = contraction.y - (contraction.y > best.x) - (contraction.y > best.y); contraction.x = x; contraction.y = y; updated_contractions.push_back(std::move(contraction)); } pos_pairs.clear(); for (int i = 0; i < inputs.size() - 1; ++i) { pos_pairs.emplace_back(i, inputs.size() - 1); } path_cost += best.cost; possible_contractions = std::move(updated_contractions); } return {path, path_cost, 
path_scaling}; } // Assumes inputs have already have had repeats and single axis sums collapsed bool can_dot(const std::vector& inputs, const Subscript& output) { if (inputs.size() != 2) { return false; } for (auto c : inputs[0].set) { // Use batched tensordot if anything is being contracted if (output.set.find(c) == output.set.end()) { return true; } } return false; } array batch_tensordot( array a, array b, std::vector a_contract, std::vector a_batch, std::vector a_concat, std::vector b_contract, std::vector b_batch, std::vector b_concat, StreamOrDevice s) { // Broadcast contracting dimensions { auto a_shape = a.shape(); auto b_shape = b.shape(); for (int i = 0; i < a_contract.size(); ++i) { auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i])); a_shape[a_contract[i]] = d; b_shape[b_contract[i]] = d; } a = broadcast_to(a, a_shape, s); b = broadcast_to(b, b_shape, s); } auto transpose_reshape = [&s]( const array& x, const std::vector& i, const std::vector& j, const std::vector& k) { std::vector reorder(i.begin(), i.end()); reorder.insert(reorder.end(), j.begin(), j.end()); reorder.insert(reorder.end(), k.begin(), k.end()); int size1 = 1; for (auto s : j) { size1 *= x.shape(s); } int size2 = 1; for (auto s : k) { size2 *= x.shape(s); } Shape shape; for (auto ax : i) { shape.push_back(x.shape(ax)); } shape.push_back(size1); shape.push_back(size2); return reshape(transpose(x, reorder, s), std::move(shape), s); }; Shape out_shape; for (auto ax : a_batch) { out_shape.push_back(a.shape(ax)); } for (auto ax : a_concat) { out_shape.push_back(a.shape(ax)); } for (auto ax : b_concat) { out_shape.push_back(b.shape(ax)); } a = transpose_reshape(a, a_batch, a_concat, a_contract); b = transpose_reshape(b, b_batch, b_contract, b_concat); return reshape(matmul(a, b, s), std::move(out_shape), s); } // Collapse repeated subscripts and return the resulting array. The subscript // is also updated in place. 
For example: // - Given an input with shape (4, 4) and subscript "ii", returns // the diagonal of shape (4,) and updates the subscript to "i". // - Given an input with shape (4, 2, 4, 2) and subscript "ijij", // returns an output with shape (4, 2) and updates the subscript // to "ij". array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) { // Build a list of (repeat chars, num repeats) auto& str = subscript.str; std::vector> repeats; std::string new_str; { std::string repeat_str; std::string no_repeat_str; std::unordered_map counts; for (int i = 0; i < str.size(); ++i) { auto [it, _] = counts.insert({str[i], 0}); it->second++; } for (auto& v : counts) { if (v.second > 1) { repeats.emplace_back(v.first, v.second); repeat_str += v.first; } } for (auto& c : str) { if (counts[c] == 1) { no_repeat_str += c; } } new_str = repeat_str + no_repeat_str; } // Build the inputs for gather auto slice_sizes = in.shape(); std::vector axes; std::vector indices; int n_expand = repeats.size(); for (auto [c, v] : repeats) { for (int i = 0; i < str.size(); ++i) { if (str[i] == c) { slice_sizes[i] = 1; axes.push_back(i); } } Shape idx_shape(n_expand--, 1); idx_shape[0] = in.shape(axes.back()); auto idx = reshape( arange(static_cast(in.shape(axes.back())), s), idx_shape, s); for (int i = 0; i < v; ++i) { indices.push_back(idx); } } in = gather(in, indices, axes, slice_sizes, s); // Update subscript string with removed dups str = new_str; // Squeeze singleton dimensions left over from the gather for (auto& ax : axes) { ax += indices[0].ndim(); } return squeeze(in, axes, s); } // Collapse repeat indices and sum single dimensions. 
// For example: // - "aa" becomes "a" // - "ij,jk->k" becoms "j,jk->k" void preprocess_einsum_inputs( std::vector& inputs, const Subscript& output, const std::vector& positions, std::vector& operands, StreamOrDevice s) { // Collapse repeat indices for (int i = 0; i < inputs.size(); ++i) { auto& in = inputs[i]; if (in.set.size() < in.str.size()) { operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s); } } // Sum indices that are only in a single input { std::unordered_map counts; for (auto& in : inputs) { for (auto c : in.set) { auto inserted = counts.insert({c, 0}); inserted.first->second++; } } for (auto c : output.set) { auto inserted = counts.insert({c, 0}); inserted.first->second++; } for (int i = 0; i < inputs.size(); ++i) { auto& in = inputs[i]; std::vector sum_axes; for (int ax = 0; ax < in.str.size(); ++ax) { if (counts[in.str[ax]] == 1) { sum_axes.push_back(ax); } } if (!sum_axes.empty()) { operands[positions[i]] = sum(operands[positions[i]], sum_axes, false, s); } for (auto it = sum_axes.rbegin(); it != sum_axes.rend(); ++it) { in.set.erase(in.str[*it]); in.str.erase(in.str.begin() + *it); } } } } array einsum_naive( std::vector inputs, const Subscript& output, const std::vector& positions, std::vector operands, StreamOrDevice s) { // Map each character to an axis std::unordered_map char_to_ax; for (auto& in : inputs) { for (auto c : in.str) { char_to_ax.insert({c, char_to_ax.size()}); } } // Expand and transpose inputs as needed for (int i = 0; i < inputs.size(); ++i) { int pos = positions[i]; auto& op = operands[pos]; // Add missing dimensions at the end if (op.ndim() != char_to_ax.size()) { auto shape = op.shape(); shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1); op = reshape(op, std::move(shape), s); } // Transpose: // - Build a vector of (char, ax) pairs for the current input // - Sort the vector by the canonical axis in char_to_ax // - Extract the sorted axis to get transpose order std::vector> str_ax; for (auto c : 
inputs[i].str) { str_ax.emplace_back(c, str_ax.size()); } for (auto [c, ax] : char_to_ax) { if (inputs[i].set.find(c) == inputs[i].set.end()) { str_ax.emplace_back(c, str_ax.size()); } } std::sort( str_ax.begin(), str_ax.end(), [&char_to_ax](const auto& x, const auto& y) { return char_to_ax[x.first] < char_to_ax[y.first]; }); // Skip the transpose if not needed if (std::is_sorted( str_ax.begin(), str_ax.end(), [](const auto& x, const auto& y) { return x.second < y.second; })) { continue; } std::vector reorder; for (auto [c, ax] : str_ax) { reorder.push_back(ax); } op = transpose(op, reorder, s); } // Multiply and sum auto out = operands[positions[0]]; for (int i = 1; i < positions.size(); ++i) { out = multiply(out, operands[positions[i]], s); } std::vector sum_axes; for (auto [c, ax] : char_to_ax) { if (output.set.find(c) == output.set.end()) { sum_axes.push_back(ax); } } if (!sum_axes.empty()) { out = sum(out, sum_axes, false, s); } // Transpose output if needed std::vector reorder; for (auto c : output.str) { reorder.push_back(char_to_ax[c]); } for (auto& r : reorder) { int offset = 0; for (auto s : sum_axes) { if (r > s) { offset++; } } r -= offset; } return transpose(out, reorder, s); } std::pair, PathInfo> einsum_path_helper( const std::string& subscripts, const std::vector& operands, const std::string& fn_name) { if (operands.size() == 0) { std::ostringstream msg; msg << "[" << fn_name << "] At least one operand is required."; throw std::invalid_argument(msg.str()); } auto [in_subscripts, out_subscript] = parse(subscripts); if (operands.size() != in_subscripts.size()) { std::ostringstream msg; msg << "[" << fn_name << "] Number of operands, " << operands.size() << ", does not match number of input subscripts, " << in_subscripts.size(); throw std::invalid_argument(msg.str()); } // Expand ellipses // 1. Collect all the characters we can use for the missing axes. // 2. 
Go over each subscript and check if all the characters are either // alphanumeric or an ellipsis. // 3. Expand the ellipsis with as many characters from the unused ones as // necessary. We use the last N characters effectively prepending with // singleton dims for inputs with fewer dimensions. // 4. For the output use the maximum size of ellipsis that we encountered in // the input. CharSet used_chars(subscripts.begin(), subscripts.end()); std::string remaining_chars; remaining_chars.reserve(52 - used_chars.size()); for (char c = 'a'; c <= 'z'; c++) { if (used_chars.find(c) == used_chars.end()) { remaining_chars += c; } } for (char c = 'A'; c <= 'Z'; c++) { if (used_chars.find(c) == used_chars.end()) { remaining_chars += c; } } int max_ellipsis_length = 0; auto check_letters_and_expand_ellipsis = [&](auto& subscript, const array* operand, int operand_idx) { bool have_ellipsis = false; int cnt_before = 0, cnt_after = 0; for (int i = 0; i < subscript.size(); i++) { if (!isalpha(subscript[i])) { if (i + 2 >= subscript.size() || subscript[i] != '.' || subscript[i + 1] != '.' || subscript[i + 2] != '.') { std::ostringstream msg; msg << "[" << fn_name << "] Subscripts must be letters, but got '" << subscript[i] << "'."; throw std::invalid_argument(msg.str()); } if (have_ellipsis) { std::ostringstream msg; msg << "[" << fn_name << "] Only one ellipsis per subscript is allowed but found more in '" << subscript << "'."; throw std::invalid_argument(msg.str()); } have_ellipsis = true; i += 2; continue; } if (have_ellipsis) { cnt_after++; } else { cnt_before++; } } if (have_ellipsis) { int ellipsis_length; if (operand != nullptr) { ellipsis_length = operand->ndim() - cnt_before - cnt_after; if (ellipsis_length < 0) { std::ostringstream msg; msg << "[" << fn_name << "] Operand " << operand_idx << " with shape " << operand->shape() << " has insufficient dimensions for subscript '" << subscript << "'. 
The ellipsis requires at least " << (cnt_before + cnt_after) << " dimensions but the operand has " << operand->ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length); } else { ellipsis_length = max_ellipsis_length; } subscript.replace( subscript.begin() + cnt_before, subscript.begin() + cnt_before + 3, remaining_chars.end() - ellipsis_length, remaining_chars.end()); } }; for (int i = 0; i < operands.size(); i++) { check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i], i); } check_letters_and_expand_ellipsis(out_subscript, nullptr, -1); CharSet out_set(out_subscript.begin(), out_subscript.end()); if (out_set.size() != out_subscript.size()) { std::ostringstream msg; msg << "[" << fn_name << "] Repeat indices not allowed in output."; throw std::invalid_argument(msg.str()); } Subscript output(out_subscript, std::move(out_set)); std::unordered_map dim_map; std::vector inputs; for (int i = 0; i < in_subscripts.size(); ++i) { auto& in = in_subscripts[i]; CharSet in_set(in.begin(), in.end()); inputs.emplace_back(in, in_set); if (in.size() != operands[i].ndim()) { std::ostringstream msg; msg << "[" << fn_name << "] Invalid number of subscripts " << in.size() << " for input " << i << " with " << operands[i].ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } // Check repeat subscripts are valid if (in_set.size() < in.size()) { std::unordered_map local_dims; for (int j = 0; j < in.size(); ++j) { auto dim = operands[i].shape(j); auto inserted = local_dims.insert({in[j], dim}); if (!inserted.second) { if (inserted.first->second != dim) { std::ostringstream msg; msg << "[" << fn_name << "] Dimensions of repeated subscripts " << "do not have the same size (" << inserted.first->second << " != " << dim << ")."; throw std::invalid_argument(msg.str()); } } } } for (int j = 0; j < in.size(); j++) { auto c = in[j]; auto dim = operands[i].shape(j); auto inserted = 
dim_map.insert({c, dim}); auto& in_dim = inserted.first->second; if (dim != 1 && in_dim != 1 && in_dim != dim) { std::ostringstream msg; msg << "[" << fn_name << "] Cannot broadcast dimension " << j << " of input " << i << " with shape " << operands[i].shape() << " to size " << in_dim << "."; throw std::invalid_argument(msg.str()); } // Ensure the broadcasted size is used in_dim = std::max(in_dim, dim); } } size_t max_size = term_size(out_subscript, dim_map); for (auto& in : in_subscripts) { max_size = std::max(max_size, term_size(in, dim_map)); } PathInfo path_info{}; // Get the full naive cost std::tie(path_info.naive_cost, path_info.naive_scaling) = compute_cost_and_scaling(inputs, output, dim_map); // Calculate the path std::vector path; if (inputs.size() <= 2) { std::vector positions(in_subscripts.size()); std::iota(positions.begin(), positions.end(), 0); path.emplace_back( std::move(inputs), std::move(output), std::move(positions)); path_info.optimized_cost = path_info.naive_cost; path_info.optimized_scaling = path_info.naive_scaling; } else { std::tie(path, path_info.optimized_cost, path_info.optimized_scaling) = greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size); // Set the final output subscript to the actual output path.back().output = std::move(output); } return {path, path_info}; } } // namespace std::pair>, std::string> einsum_path( const std::string& subscripts, const std::vector& operands) { auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum_path"); std::vector> pos_path; for (auto& p : path) { pos_path.push_back(p.positions); } std::ostringstream path_print; path_print << " Complete contraction: " << subscripts << "\n" << " Naive scaling: " << path_info.naive_scaling << "\n" << " Optimized scaling: " << path_info.optimized_scaling << "\n" << " Naive FLOP count: " << path_info.naive_cost << "\n" << " Optimized FLOP count: " << path_info.optimized_cost << "\n"; // TODO add more info here return {pos_path, 
path_print.str()}; } array einsum( const std::string& subscripts, const std::vector& operands, StreamOrDevice s /* = {} */) { auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum"); auto inputs = operands; for (auto& node : path) { preprocess_einsum_inputs( node.inputs, node.output, node.positions, inputs, s); if (can_dot(node.inputs, node.output)) { auto& in_a = node.inputs[0]; auto& in_b = node.inputs[1]; auto& out = node.output; std::vector a_contract; std::vector a_batch; std::vector a_concat; for (int i = 0; i < in_a.str.size(); ++i) { auto c = in_a.str[i]; if (out.set.find(c) == out.set.end()) { // Not in the output, contraction a_contract.push_back(i); } else if (in_b.set.find(c) != in_b.set.end()) { // Not a contraction but in both inputs, batch dim a_batch.push_back(i); } else { // Not a batch dim or contract dim, so concat dim a_concat.push_back(i); } } std::vector b_contract; std::vector b_batch; std::vector b_concat; for (auto a_i : a_contract) { b_contract.push_back(in_b.str.find(in_a.str[a_i])); } for (auto a_i : a_batch) { b_batch.push_back(in_b.str.find(in_a.str[a_i])); } for (int i = 0; i < in_b.str.size(); ++i) { auto c = in_b.str[i]; if (out.set.find(c) != out.set.end() && in_a.set.find(c) == in_a.set.end()) { b_concat.push_back(i); } } auto& a = inputs[node.positions[0]]; auto& b = inputs[node.positions[1]]; std::unordered_map char_map; for (auto i : a_batch) { char_map.insert({in_a.str[i], char_map.size()}); } for (auto i : a_concat) { char_map.insert({in_a.str[i], char_map.size()}); } for (auto i : b_concat) { char_map.insert({in_b.str[i], char_map.size()}); } inputs.emplace_back(batch_tensordot( a, b, std::move(a_contract), std::move(a_batch), std::move(a_concat), std::move(b_contract), std::move(b_batch), std::move(b_concat), s)); std::vector reorder; for (auto c : node.output.str) { reorder.push_back(char_map[c]); } inputs.back() = transpose(inputs.back(), reorder, s); } else { inputs.emplace_back( 
einsum_naive(node.inputs, node.output, node.positions, inputs, s)); } // Positions are always sorted increasing, so start from the back for (auto it = node.positions.rbegin(); it != node.positions.rend(); ++it) { inputs.erase(inputs.begin() + *it); } } return inputs.front(); } } // namespace mlx::core ================================================ FILE: mlx/einsum.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include #include #include #include "mlx/api.h" #include "mlx/array.h" #include "mlx/utils.h" namespace mlx::core { MLX_API std::pair>, std::string> einsum_path( const std::string& subscripts, const std::vector& operands); MLX_API array einsum( const std::string& subscripts, const std::vector& operands, StreamOrDevice s = {}); } // namespace mlx::core ================================================ FILE: mlx/event.h ================================================ // Copyright © 2024 Apple Inc. #pragma once #include #include #include #include "mlx/stream.h" namespace mlx::core { class Event { public: Event() {}; explicit Event(Stream stream); // Wait for the event to be signaled at its current value void wait(); // Wait in the given stream for the event to be signaled at its current value void wait(Stream stream); // Signal the event at its current value in the given stream void signal(Stream stream); // Check if the event has been signaled at its current value bool is_signaled() const; // Check if the event is valid bool valid() const { return event_ != nullptr; } uint64_t value() const { return value_; } void set_value(uint64_t v) { value_ = v; } const Stream& stream() const { if (!valid()) { throw std::runtime_error( "[Event::stream] Cannot access stream on invalid event."); } return stream_; } private: // Default constructed stream should never be used // since the event is not yet valid Stream stream_{0, Device::cpu}; std::shared_ptr event_{nullptr}; uint64_t value_{0}; }; } // namespace mlx::core 
================================================ FILE: mlx/export.cpp ================================================ // Copyright © 2024 Apple Inc. #include "mlx/export.h" #include #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" #include "mlx/graph_utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" #include "mlx/version.h" // clang-format off #define SERIALIZE_PRIMITIVE(primitive, ...) \ { \ #primitive, { \ serialize_primitive, \ deserialize_primitive, \ primitive_state, \ {__VA_ARGS__} \ } \ } // clang-format on bool is_big_endian() { int num = 1; return *reinterpret_cast(&num) != 1; } namespace mlx::core { using namespace mlx::core::fast; using Reader = io::ParallelFileReader; using Writer = io::FileWriter; struct PrimitiveSerializer { using Serializer = std::function; using Deserializer = std::function(Reader&, Stream s)>; using StateExtractor = std::function(const Primitive&)>; PrimitiveSerializer( Serializer serialize, Deserializer deserialize, StateExtractor extract_state, std::vector keys = {}) : serialize(std::move(serialize)), deserialize(std::move(deserialize)), extract_state(std::move(extract_state)), keys(std::move(keys)) {}; Serializer serialize; Deserializer deserialize; StateExtractor extract_state; std::vector keys; }; template constexpr bool is_iterable = false; template constexpr bool is_iterable< T, std::void_t< decltype(std::declval().begin()), decltype(std::declval().end())>> = true; template