This layer does not automatically gather all outputs from each device. This is
an intended and :ref:`useful design choice <useful_design_choices>`.

:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is
the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. Similar to
:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
included in any gradient computation.
:class:`ShardedToAllLinear <mlx.nn.ShardedToAllLinear>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This layer expects inputs that are sharded along the feature dimension and
shards the weight matrix along the input dimension across all devices in the
:class:`mlx.core.distributed.Group`. The layer automatically aggregates the
results using :func:`mlx.core.distributed.all_sum`, so all devices in the
group will have the same result.
For example, consider an :class:`mlx.nn.ShardedToAllLinear` layer with
``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
and a device group with 2 devices. The layer shards the weight matrix along the
input dimension across the two devices. Each device computes a partial
``(4, 2)`` output, which is then summed with the other devices' outputs to
produce the layer output.
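The arithmetic in this example can be checked with a plain-Python simulation.
This is only a sketch: it does not use MLX or real devices, the helper names
``matmul`` and ``all_sum`` are local stand-ins, and the weight is stored as
``(input_dims, output_dims)`` purely for readability.

.. code-block:: python

   # Plain-Python simulation of a sharded-to-all linear layer on 2 "devices".
   # No MLX or real communication is involved; all names here are hypothetical.

   def matmul(a, b):
       """Multiply an (n, k) matrix by a (k, m) matrix given as nested lists."""
       return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

   def all_sum(partials):
       """Element-wise sum of equally shaped matrices (stand-in for an all-reduce)."""
       return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]

   x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # (4, 2) full input
   w = [[1.0, 0.0], [2.0, 1.0]]  # (2, 2) weight, laid out as (input_dims, output_dims)

   n_devices = 2
   # Device r holds feature column r of the input and row r of the weight.
   x_shards = [[[row[r]] for row in x] for r in range(n_devices)]  # each (4, 1)
   w_shards = [[w[r]] for r in range(n_devices)]                   # each (1, 2)

   # Every device produces a full-shape (4, 2) partial result.
   partials = [matmul(x_shards[r], w_shards[r]) for r in range(n_devices)]

   # Summing the partial results recovers the unsharded output.
   y = all_sum(partials)
   assert y == matmul(x, w)

Each device only ever sees half of the input features and half of the weight,
yet the summed result matches the unsharded product.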
This layer does not automatically shard the inputs along the feature dimension
for you. It is necessary to create a "partial" input structure to feed into the
layer. This is an intended and :ref:`useful design choice
<useful_design_choices>`.

:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is
the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. Similar to
:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
included in any gradient computation.
Shard Utility Functions
-----------------------
:func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Converts a regular linear layer into a tensor parallel layer that distributes
computation across multiple devices. Takes an existing :class:`mlx.nn.Linear`
or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer
(either :class:`mlx.nn.AllToShardedLinear` or
:class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The
original layer is not modified.
:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Splits the parameters of an existing layer across multiple devices by modifying
the layer in-place. Unlike :func:`shard_linear
<mlx.nn.layers.distributed.shard_linear>`, this function does not create a new
layer or add distributed communication. The layer itself must handle
distributed communication if needed.
.. _useful_design_choices:
Useful Design Choices
---------------------
The design choices above regarding when operations are done automatically are intentional and make model training and inference easier.
All-to-sharded and sharded-to-all layers naturally go together because the
output of the former layer is exactly the input needed for the latter.
This removes the need for an intermediate gather step between the layers,
reducing communication overhead.
This is why :class:`mlx.nn.AllToShardedLinear` does not aggregate results
automatically and why :class:`mlx.nn.ShardedToAllLinear` does not shard inputs
automatically. This lets the two layers be placed back to back and work
together without an intermediate communication step.
We can demonstrate this through a simple model using our two types of
distributed layers.
.. code-block:: python

   x = ...  # some (4, 2) model input: batch size 4, feature size 2

   l1 = nn.AllToShardedLinear(2, 2, bias=False)  # initialize the layer
   l1_out = l1(x)  # (4, 1) output

   l2 = nn.ShardedToAllLinear(2, 2, bias=False)
   l2_out = l2(l1_out)  # (4, 2) output
Figure: a visualization of the simple MLX model using all-to-sharded then
sharded-to-all tensor parallelism across 2 devices.
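To see why no gather is needed between the two layers, the same two-layer
model can be simulated in plain Python. This is a sketch under simplifying
assumptions: the "devices" are loop iterations, the weights are illustrative
values stored as ``(input_dims, output_dims)``, and the final sum stands in for
the communication a sharded-to-all layer performs.

.. code-block:: python

   # Plain-Python simulation of all-to-sharded followed by sharded-to-all.
   # Hypothetical values and helpers; no MLX calls are made.

   def matmul(a, b):
       """Multiply an (n, k) matrix by a (k, m) matrix given as nested lists."""
       return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

   x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # (4, 2) input on every device
   w1 = [[1.0, 2.0], [3.0, 4.0]]    # first layer weight, (2, 2)
   w2 = [[0.5, -1.0], [2.0, 1.0]]   # second layer weight, (2, 2)

   # Reference result computed without any sharding.
   reference = matmul(matmul(x, w1), w2)

   n_devices = 2
   partials = []
   for r in range(n_devices):
       w1_shard = [[row[r]] for row in w1]  # column r of w1 (all-to-sharded)
       w2_shard = [w2[r]]                   # row r of w2 (sharded-to-all)
       l1_out = matmul(x, w1_shard)         # (4, 1), never leaves device r
       partials.append(matmul(l1_out, w2_shard))  # (4, 2) partial output

   # The only communication step: sum the partial outputs across devices.
   l2_out = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
   assert l2_out == reference

The ``(4, 1)`` intermediate stays local to each device, so a single reduction
at the end is the only cross-device step in this sketch.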
LLM Inference with Tensor Parallelism
-------------------------------------
We can apply these TP techniques to LLMs in order to enable inference for much
larger models by sharding parameters from huge layers across multiple devices.
To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama
Inference ` example. In this example, we will use the same
inference script as the Llama Inference example, which can be found in
`mlx-examples`_.
Our first edit is to initialize the distributed communication group and get the
current process rank:
.. code-block:: python
world = mx.distributed.init()
rank = world.rank()
Next, let's look at the current architecture of the transformer block and see how we can apply tensor parallelism:
This architecture has two natural places where
tensor parallelism can be applied: the attention block and the FFN
block. Both follow the same pattern: multiple parallel linear layers operating
on the same input, followed by a single output linear layer. In the attention
block, the Q, K, and V projections are sharded along the output dimension
(all-to-sharded), and the output projection is sharded along the input
dimension (sharded-to-all). Similarly, in the FFN block, the gate and up
projections become all-to-sharded layers, and the down projection becomes a
sharded-to-all layer.
The intermediate operations between the linear layers (RoPE, softmax, scaled
dot-product attention in the attention block, and element-wise multiplication
in the FFN block) do not impede the use of our TP paradigm. These operations
are either:
- **Element-wise operations** (RoPE, element-wise multiplication): These
  operate independently on each element or position, preserving the sharding
  pattern without requiring cross-device communication.
- **Operations on non-sharded dimensions** (softmax, scaled dot-product
  attention): These operate along dimensions that are not sharded (such as the
  sequence length or head dimensions), so they can be computed independently on
  each device. The attention computation ``Q @ K^T`` and ``scores @ V`` works
  correctly with sharded Q, K, and V tensors because sharding the feature
  dimension assigns whole attention heads to each device, and every head is
  computed independently of the others. The results remain properly sharded
  for the subsequent sharded-to-all layer.
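The claim about element-wise operations can be checked numerically for the FFN
block. The following plain-Python sketch assumes a gated feed-forward of the
form ``w2(silu(w1(x)) * w3(x))`` with small illustrative weights stored as
``(input_dims, output_dims)``. All helper names are hypothetical and the
"devices" are simulated by a loop.

.. code-block:: python

   import math

   def matmul(a, b):
       """Multiply an (n, k) matrix by a (k, m) matrix given as nested lists."""
       return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

   def silu(v):
       return v / (1.0 + math.exp(-v))

   def gated(gate, up):
       """Element-wise silu(gate) * up on two equally shaped matrices."""
       return [[silu(g) * u for g, u in zip(gr, ur)] for gr, ur in zip(gate, up)]

   def columns(m, lo, hi):
       """Select columns lo..hi-1 of a matrix."""
       return [row[lo:hi] for row in m]

   x = [[0.5, -1.0], [2.0, 1.5]]                            # (2, 2) input
   w1 = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]        # gate projection, (2, 4)
   w3 = [[1.0, -1.0, 0.5, 2.0], [0.0, 1.0, -0.5, 1.0]]      # up projection, (2, 4)
   w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]   # down projection, (4, 2)

   # Reference: the unsharded feed-forward block.
   full = matmul(gated(matmul(x, w1), matmul(x, w3)), w2)

   # Tensor parallel: each simulated device owns 2 of the 4 hidden features.
   n_devices, per_device = 2, 2
   partials = []
   for r in range(n_devices):
       lo, hi = r * per_device, (r + 1) * per_device
       # The element-wise product only touches local hidden features.
       h = gated(matmul(x, columns(w1, lo, hi)), matmul(x, columns(w3, lo, hi)))
       partials.append(matmul(h, w2[lo:hi]))  # partial (2, 2) output

   # Summing the partial outputs is the only cross-device step.
   sharded = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]

   assert all(
       math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
       for full_row, sharded_row in zip(full, sharded)
       for a, b in zip(full_row, sharded_row)
   )

The comparison uses a tolerance because the sharded version adds the hidden
features in a different order than the reference.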
To implement sharding in our Llama inference, we use :func:`shard_linear
<mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with
distributed communication. This is easier than using :func:`shard_inplace
<mlx.nn.layers.distributed.shard_inplace>` and implementing the steps manually
in the :code:`__call__` function.
The following code shows how to shard the Attention block. The Q, K, and V
projection layers are converted to all-to-sharded layers, while the output
projection is converted to a sharded-to-all layer. The number of heads is also
adjusted to account for the sharding:
.. code-block:: python

   # ... in Attention class
   def shard(self, group: mx.distributed.Group):
       # Each device keeps only its share of the query and key/value heads
       self.n_heads = self.n_heads // group.size()
       self.n_kv_heads = self.n_kv_heads // group.size()

       self.wq = nn.layers.distributed.shard_linear(self.wq, "all-to-sharded", group=group)
       self.wk = nn.layers.distributed.shard_linear(self.wk, "all-to-sharded", group=group)
       self.wv = nn.layers.distributed.shard_linear(self.wv, "all-to-sharded", group=group)
       self.wo = nn.layers.distributed.shard_linear(self.wo, "sharded-to-all", group=group)
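The head adjustment is plain integer division, so it is worth checking the
numbers before sharding. The helper below is a hypothetical sketch, not part of
MLX: it assumes the head counts should divide evenly by the group size, and the
example values (32 query heads, 8 key/value heads, head dimension 128) are
illustrative only.

.. code-block:: python

   def shard_attention_dims(n_heads, n_kv_heads, head_dim, group_size):
       """Return the per-device head counts and projection widths.

       Hypothetical helper: it mirrors the ``// group.size()`` arithmetic and
       raises if the heads cannot be split evenly (an assumption of this sketch).
       """
       if n_heads % group_size or n_kv_heads % group_size:
           raise ValueError("head counts must be divisible by the group size")
       local_heads = n_heads // group_size
       local_kv_heads = n_kv_heads // group_size
       return {
           "n_heads": local_heads,
           "n_kv_heads": local_kv_heads,
           "q_features": local_heads * head_dim,      # local output width of wq
           "kv_features": local_kv_heads * head_dim,  # local output width of wk and wv
       }

   dims = shard_attention_dims(n_heads=32, n_kv_heads=8, head_dim=128, group_size=2)
   print(dims)

With two devices, each one ends up with 16 query heads and 4 key/value heads in
this example.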
Similarly, the FeedForward block is sharded by converting the gate (w1) and up
(w3) projections to all-to-sharded layers, and the down projection (w2) to
a sharded-to-all layer:
.. code-block:: python

   # ... in FeedForward class
   def shard(self, group: mx.distributed.Group):
       self.w1 = nn.layers.distributed.shard_linear(self.w1, "all-to-sharded", group=group)
       self.w2 = nn.layers.distributed.shard_linear(self.w2, "sharded-to-all", group=group)
       self.w3 = nn.layers.distributed.shard_linear(self.w3, "all-to-sharded", group=group)
Finally, in our :code:`load_model` function, we need to apply our sharding
functions to all transformer layers when using multiple devices:
.. code-block:: python

   # ... in load_model function
   if world.size() > 1:
       # convert Linear layers in Transformer/FFN to appropriate Sharded Layers
       for layer in model.layers:
           layer.attention.shard(group=world)
           layer.feed_forward.shard(group=world)
This allows us to run the Llama inference script as usual with
:code:`python llama.py`, and also across two (or more) devices with
:code:`mlx.launch -n 2 llama.py`.
.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
================================================
FILE: docs/src/index.rst
================================================
MLX
===
MLX is a NumPy-like array framework designed for efficient and flexible machine
learning on Apple silicon, brought to you by Apple machine learning research.
The Python API closely follows NumPy with a few exceptions. MLX also has a
fully featured C++ API which closely follows the Python API.
The main differences between MLX and NumPy are:
- **Composable function transformations**: MLX has composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Multi-device**: Operations can run on any of the supported devices (CPU,
GPU, ...)
The design of MLX is inspired by frameworks like `PyTorch
`_, `Jax `_, and
`ArrayFire `_. A notable difference between these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
are the CPU and GPU.
.. toctree::
:caption: Install
:maxdepth: 1
install
.. toctree::
:caption: Usage
:maxdepth: 1
usage/quick_start
usage/lazy_evaluation
usage/unified_memory
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams
usage/export
.. toctree::
:caption: Examples
:maxdepth: 1
examples/linear_regression
examples/mlp
examples/llama-inference
examples/data_parallelism
examples/tensor_parallelism
.. toctree::
:caption: Python API Reference
:maxdepth: 1
python/array
python/data_types
python/devices_and_streams
python/export
python/ops
python/random
python/transforms
python/fast
python/fft
python/linalg
python/metal
python/cuda
python/memory_management
python/nn
python/optimizers
python/distributed
python/tree_utils
.. toctree::
:caption: C++ API Reference
:maxdepth: 1
cpp/ops
.. toctree::
:caption: Further Reading
:maxdepth: 1
dev/extensions
dev/metal_debugger
dev/metal_logging
dev/custom_metal_kernels
dev/mlx_in_cpp
================================================
FILE: docs/src/install.rst
================================================
.. _build_and_install:
Build and Install
=================
Python Installation
-------------------
MLX is available on PyPI. To use MLX on your own Apple silicon computer, run:
.. code-block:: shell
pip install mlx
To install from PyPI your system must meet the following requirements:
- Using `Apple silicon `_
- Using a native Python >= 3.10
- macOS >= 14.0
.. note::
   MLX is only available on devices running macOS 14.0 or higher.
CUDA
^^^^
MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install mlx[cuda12]
To install the CUDA package from PyPI your system must meet the following
requirements:
- Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install mlx[cpu]
To install the CPU-only package from PyPI your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.10
Troubleshooting
^^^^^^^^^^^^^^^
*My OS and Python versions are in the required range but pip still does not find
a matching distribution.*
You are probably using a non-native Python. The output of
.. code-block:: shell
python -c "import platform; print(platform.processor())"
should be ``arm``. If it is ``i386`` (and you have an M series machine) then you
are using a non-native Python. Switch your Python to a native Python. A good
way to do this is with `Conda `_.
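The same check can be scripted. The helper below is a hypothetical convenience
function, not something MLX ships. It only encodes the rule stated above, that
a native Python on Apple silicon reports ``arm``.

.. code-block:: python

   import platform

   def is_native_arm(processor: str) -> bool:
       """Return True when the processor string reported by Python is ``arm``."""
       return processor == "arm"

   if __name__ == "__main__":
       proc = platform.processor()
       status = "native" if is_native_arm(proc) else "non-native (or not Apple silicon)"
       print(f"{proc!r}: {status}")

On an M series machine, a result other than ``arm`` suggests the interpreter is
running under Rosetta.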
Build from source
-----------------
Build Requirements
^^^^^^^^^^^^^^^^^^
- ``libblas-dev``, ``liblapack-dev``, and ``liblapacke-dev`` (Linux)
- A C++ compiler with C++20 support (e.g. Clang >= 15.0)
- `cmake `_ -- version 3.25 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0
.. note::
   Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
   the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section
   <build shell>` below.
Python API
^^^^^^^^^^
.. _python install:
To build and install the MLX Python library from source, first clone MLX from
`its GitHub repo `_:
.. code-block:: shell
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Then simply build and install MLX using pip:
.. code-block:: shell
pip install .
For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell
pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
python setup.py build_ext --inplace
Run the tests with:
.. code-block:: shell
python -m unittest discover python/tests
C++ API
^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source.
As with the Python library, to build and install the MLX C++ library, start
by cloning MLX from `its GitHub repo
`_:
.. code-block:: shell
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Create a build directory and run CMake and make:
.. code-block:: shell
mkdir -p build && cd build
cmake .. && make -j
Run tests with:
.. code-block:: shell
make test
Install with:
.. code-block:: shell
make install
Note that the built ``mlx.metallib`` file should either be in the same
directory as the executable statically linked to ``libmlx.a``, or the
preprocessor constant ``METAL_PATH`` should be defined at build time to point
to the built Metal library.
.. list-table:: Build Options
   :widths: 25 8
   :header-rows: 1

   * - Option
     - Default
   * - MLX_BUILD_TESTS
     - ON
   * - MLX_BUILD_EXAMPLES
     - OFF
   * - MLX_BUILD_BENCHMARKS
     - OFF
   * - MLX_BUILD_METAL
     - ON
   * - MLX_BUILD_CPU
     - ON
   * - MLX_BUILD_PYTHON_BINDINGS
     - OFF
   * - MLX_METAL_DEBUG
     - OFF
   * - MLX_BUILD_SAFETENSORS
     - ON
   * - MLX_BUILD_GGUF
     - ON
   * - MLX_METAL_JIT
     - OFF
.. note::

   If you have multiple Xcode installations and wish to use
   a specific one while building, you can do so by adding the
   following environment variable before building

   .. code-block:: shell

      export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"

   Further, you can use the following command to find out which
   macOS SDK will be used

   .. code-block:: shell

      xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
.. code-block:: shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note that run-time compilation incurs a cold-start cost
which can be anywhere from a few hundred milliseconds to a few seconds
depending on the application. Once a kernel is compiled, it will be cached by
the system. The Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python
<python install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~
You see the following error when you try to build:
.. code-block:: shell
error: unable to find utility "metal", not a developer tool or in PATH
To fix this, first make sure you have Xcode installed:
.. code-block:: shell
xcode-select --install
Then set the active developer directory:
.. code-block:: shell
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
x86 Shell
~~~~~~~~~
.. _build shell:
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm,
``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.
Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your
terminal.
Verify the terminal is now running natively with the following command:
.. code-block:: shell
$ uname -p
arm
Also check that cmake is using the correct architecture:
.. code-block:: shell
$ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR
CMAKE_HOST_SYSTEM_PROCESSOR "arm64"
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cache with ``rm -rf build/`` and try again.
================================================
FILE: docs/src/python/array.rst
================================================
.. _array:
Array
=====
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
array
array.astype
array.at
array.item
array.tolist
array.dtype
array.itemsize
array.nbytes
array.ndim
array.shape
array.size
array.real
array.imag
array.abs
array.all
array.any
array.argmax
array.argmin
array.conj
array.cos
array.cummax
array.cummin
array.cumprod
array.cumsum
array.diag
array.diagonal
array.exp
array.flatten
array.log
array.log10
array.log1p
array.log2
array.logcumsumexp
array.logsumexp
array.max
array.mean
array.min
array.moveaxis
array.prod
array.reciprocal
array.reshape
array.round
array.rsqrt
array.sin
array.split
array.sqrt
array.square
array.squeeze
array.std
array.sum
array.swapaxes
array.transpose
array.T
array.var
array.view
================================================
FILE: docs/src/python/cuda.rst
================================================
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available
================================================
FILE: docs/src/python/data_types.rst
================================================
.. _data_types:
Data Types
==========
.. currentmodule:: mlx.core
The default floating point type is ``float32`` and the default integer type is
``int32``. The table below shows supported values for :obj:`Dtype`.
.. list-table:: Supported Data Types
   :widths: 5 3 20
   :header-rows: 1

   * - Type
     - Bytes
     - Description
   * - ``bool_``
     - 1
     - Boolean (``True``, ``False``) data type
   * - ``uint8``
     - 1
     - 8-bit unsigned integer
   * - ``uint16``
     - 2
     - 16-bit unsigned integer
   * - ``uint32``
     - 4
     - 32-bit unsigned integer
   * - ``uint64``
     - 8
     - 64-bit unsigned integer
   * - ``int8``
     - 1
     - 8-bit signed integer
   * - ``int16``
     - 2
     - 16-bit signed integer
   * - ``int32``
     - 4
     - 32-bit signed integer
   * - ``int64``
     - 8
     - 64-bit signed integer
   * - ``bfloat16``
     - 2
     - 16-bit brain float (e8, m7)
   * - ``float16``
     - 2
     - 16-bit IEEE float (e5, m10)
   * - ``float32``
     - 4
     - 32-bit float
   * - ``float64``
     - 8
     - 64-bit double
   * - ``complex64``
     - 8
     - 64-bit complex float
.. note::

   Arrays with type ``float64`` only work with CPU operations. Using
   ``float64`` arrays on the GPU will result in an exception.
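The byte sizes in the table determine how much memory an array's data
occupies: the number of elements times the size of one element. The sketch
below reproduces that arithmetic in plain Python. The ``ITEMSIZE`` mapping is
copied from the table, and ``nbytes`` is a hypothetical helper rather than an
MLX function.

.. code-block:: python

   import math

   # Bytes per element, copied from the table of supported data types.
   ITEMSIZE = {
       "bool_": 1, "uint8": 1, "uint16": 2, "uint32": 4, "uint64": 8,
       "int8": 1, "int16": 2, "int32": 4, "int64": 8,
       "bfloat16": 2, "float16": 2, "float32": 4, "float64": 8,
       "complex64": 8,
   }

   def nbytes(shape, dtype):
       """Bytes needed to store the data of an array with this shape and dtype."""
       return math.prod(shape) * ITEMSIZE[dtype]

   # A (1024, 1024) matrix takes half the memory in float16 compared to float32.
   half = nbytes((1024, 1024), "float16")
   single = nbytes((1024, 1024), "float32")
   assert single == 2 * half

On a real array the same quantities are exposed by the ``array.itemsize`` and
``array.nbytes`` properties.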
Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
.. autosummary::
:toctree: _autosummary
Dtype
DtypeCategory
issubdtype
finfo
================================================
FILE: docs/src/python/devices_and_streams.rst
================================================
.. _devices_and_streams:
Devices and Streams
===================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
Device
Stream
default_device
set_default_device
default_stream
new_stream
set_default_stream
stream
synchronize
device_count
device_info
================================================
FILE: docs/src/python/distributed.rst
================================================
.. _distributed:
.. currentmodule:: mlx.core.distributed
Distributed Communication
==========================
MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
.. autosummary::
:toctree: _autosummary
Group
is_available
init
all_sum
all_gather
send
recv
recv_like
================================================
FILE: docs/src/python/export.rst
================================================
.. _export:
Export Functions
================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
export_function
import_function
exporter
export_to_dot
================================================
FILE: docs/src/python/fast.rst
================================================
.. _fast:
Fast
====
.. currentmodule:: mlx.core.fast
.. autosummary::
:toctree: _autosummary
rms_norm
layer_norm
rope
scaled_dot_product_attention
metal_kernel
cuda_kernel
================================================
FILE: docs/src/python/fft.rst
================================================
.. _fft:
FFT
===
.. currentmodule:: mlx.core.fft
.. autosummary::
:toctree: _autosummary
fft
ifft
fft2
ifft2
fftn
ifftn
rfft
irfft
rfft2
irfft2
rfftn
irfftn
fftshift
ifftshift
================================================
FILE: docs/src/python/linalg.rst
================================================
.. _linalg:
Linear Algebra
==============
.. currentmodule:: mlx.core.linalg
.. autosummary::
:toctree: _autosummary
inv
tri_inv
norm
cholesky
cholesky_inv
cross
qr
svd
eigvals
eig
eigvalsh
eigh
lu
lu_factor
pinv
solve
solve_triangular
================================================
FILE: docs/src/python/memory_management.rst
================================================
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
================================================
FILE: docs/src/python/metal.rst
================================================
Metal
=====
.. currentmodule:: mlx.core.metal
.. autosummary::
:toctree: _autosummary
is_available
device_info
start_capture
stop_capture
================================================
FILE: docs/src/python/nn/distributed.rst
================================================
.. _nn_distributed:
Distributed
-----------
Helper Routines
^^^^^^^^^^^^^^^
The :code:`mlx.nn.layers.distributed` package contains helpful routines to
create sharded layers from existing :class:`Modules <mlx.nn.Module>`.
.. currentmodule:: mlx.nn.layers.distributed
.. autosummary::
:toctree: _autosummary
shard_linear
shard_inplace
Layers
^^^^^^
.. currentmodule:: mlx.nn
.. autosummary::
:toctree: _autosummary
:template: nn-module-template.rst
AllToShardedLinear
ShardedToAllLinear
QuantizedAllToShardedLinear
QuantizedShardedToAllLinear
================================================
FILE: docs/src/python/nn/functions.rst
================================================
.. _nn_functions:
.. currentmodule:: mlx.nn
Functions
---------
Layers without parameters (e.g. activation functions) are also provided as
simple functions.
.. autosummary::
:toctree: _autosummary_functions
:template: nn-module-template.rst
elu
celu
gelu
gelu_approx
gelu_fast_approx
glu
hard_shrink
hard_tanh
hardswish
leaky_relu
log_sigmoid
log_softmax
mish
prelu
relu
relu2
relu6
selu
sigmoid
silu
softmax
softmin
softplus
softshrink
step
tanh
================================================
FILE: docs/src/python/nn/init.rst
================================================
.. _init:
.. currentmodule:: mlx.nn.init
Initializers
------------
The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.
For example:
.. code:: python

   import mlx.core as mx
   import mlx.nn as nn

   init_fn = nn.init.uniform()

   # Produces a [2, 2] uniform matrix
   param = init_fn(mx.zeros((2, 2)))
To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a
uniform distribution, you can do:
.. code:: python

   import mlx.nn as nn

   model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
   init_fn = nn.init.uniform(low=-0.1, high=0.1)
   model.apply(init_fn)
.. autosummary::
:toctree: _autosummary
constant
normal
uniform
identity
glorot_normal
glorot_uniform
he_normal
he_uniform
================================================
FILE: docs/src/python/nn/layers.rst
================================================
.. _layers:
.. currentmodule:: mlx.nn
Layers
------
.. autosummary::
:toctree: _autosummary
:template: nn-module-template.rst
ALiBi
AllToShardedLinear
AvgPool1d
AvgPool2d
AvgPool3d
BatchNorm
CELU
Conv1d
Conv2d
Conv3d
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
Dropout
Dropout2d
Dropout3d
Embedding
ELU
GELU
GLU
GroupNorm
GRU
HardShrink
HardTanh
Hardswish
InstanceNorm
LayerNorm
LeakyReLU
Linear
LogSigmoid
LogSoftmax
LSTM
MaxPool1d
MaxPool2d
MaxPool3d
Mish
MultiHeadAttention
PReLU
QuantizedAllToShardedLinear
QuantizedEmbedding
QuantizedLinear
QuantizedShardedToAllLinear
RMSNorm
ReLU
ReLU2
ReLU6
RNN
RoPE
SELU
Sequential
ShardedToAllLinear
Sigmoid
SiLU
SinusoidalPositionalEncoding
Softmin
Softshrink
Softsign
Softmax
Softplus
Step
Tanh
Transformer
Upsample
================================================
FILE: docs/src/python/nn/losses.rst
================================================
.. _losses:
.. currentmodule:: mlx.nn.losses
Loss Functions
--------------
.. autosummary::
:toctree: _autosummary_functions
:template: nn-module-template.rst
binary_cross_entropy
cosine_similarity_loss
cross_entropy
gaussian_nll_loss
hinge_loss
huber_loss
kl_div_loss
l1_loss
log_cosh_loss
margin_ranking_loss
mse_loss
nll_loss
smooth_l1_loss
triplet_loss
================================================
FILE: docs/src/python/nn/module.rst
================================================
Module
======
.. currentmodule:: mlx.nn
.. autoclass:: Module
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Module.training
Module.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Module.apply
Module.apply_to_modules
Module.children
Module.eval
Module.filter_and_map
Module.freeze
Module.leaf_modules
Module.load_weights
Module.modules
Module.named_modules
Module.parameters
Module.save_weights
Module.set_dtype
Module.train
Module.trainable_parameters
Module.unfreeze
Module.update
Module.update_modules
================================================
FILE: docs/src/python/nn.rst
================================================
.. _nn:
.. currentmodule:: mlx.nn
Neural Networks
===============
Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this
requires the user to repeatedly write the same simple neural network operations
and to handle all the parameter state and initialization manually and
explicitly.
The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for finetuning and more.
Quick Start with Neural Networks
---------------------------------
.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   class MLP(nn.Module):
       def __init__(self, in_dims: int, out_dims: int):
           super().__init__()

           self.layers = [
               nn.Linear(in_dims, 128),
               nn.Linear(128, 128),
               nn.Linear(128, out_dims),
           ]

       def __call__(self, x):
           for i, l in enumerate(self.layers):
               x = mx.maximum(x, 0) if i > 0 else x
               x = l(x)
           return x

   # The model is created with all its parameters but nothing is initialized
   # yet because MLX is lazily evaluated
   mlp = MLP(2, 10)

   # We can access its parameters by calling mlp.parameters()
   params = mlp.parameters()
   print(params["layers"][0]["weight"].shape)

   # Printing a parameter will cause it to be evaluated and thus initialized
   print(params["layers"][0])

   # We can also force evaluate all parameters to initialize the model
   mx.eval(mlp.parameters())

   # A simple loss function.
   # NOTE: It doesn't matter how it uses the mlp model. It currently captures
   #       it from the local scope. It could be a positional argument or a
   #       keyword argument.
   def l2_loss(x, y):
       y_hat = mlp(x)
       return (y_hat - y).square().mean()

   # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
   # gradient with respect to `mlp.trainable_parameters()`
   loss_and_grad = nn.value_and_grad(mlp, l2_loss)
.. _module_class:
The Module Class
----------------
The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.
Parameters
^^^^^^^^^^
A parameter of a module is any public member of type :class:`mlx.core.array` (its
name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.
:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.
A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. When using
:meth:`mlx.nn.value_and_grad`, the gradients returned will be with respect to
the trainable (non-frozen) parameters.
Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^
MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.
Inspecting Modules
^^^^^^^^^^^^^^^^^^
The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the ``MLP`` with:
.. code-block:: python

   print(mlp)

This will display:

.. code-block:: shell

   MLP(
     (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
     (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
     (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
   )
To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:
.. code-block:: python

   from mlx.utils import tree_map
   shapes = tree_map(lambda p: p.shape, mlp.parameters())
As another example, you can count the number of parameters in a :class:`Module`
with:
.. code-block:: python
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
Value and Grad
--------------
Using a :class:`Module` does not preclude using MLX's higher-order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,
these function transformations assume pure functions, namely the parameters
should be passed as an argument to the function being transformed.
There is an easy pattern to achieve that with MLX modules:

.. code-block:: python

  model = ...

  def f(params, other_inputs):
      model.update(params)  # <---- Necessary to make the model use the passed parameters
      return model(other_inputs)

  f(model.trainable_parameters(), mx.zeros((10,)))

However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.
In detail:
- it wraps the passed function with a function that calls :meth:`Module.update`
to make sure the model is using the provided parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function
that also computes the gradients with respect to the passed parameters.
- it wraps the returned function with a function that passes the trainable
parameters as the first argument to the function returned by
:meth:`mlx.core.value_and_grad`
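
The three steps above can be sketched by hand as follows. This is an
illustrative re-implementation under the assumptions stated in the list, not
the library's source; the name ``sketch_value_and_grad``, the toy model and the
loss are made up for the example.

.. code-block:: python

  import mlx.core as mx
  import mlx.nn as nn

  def sketch_value_and_grad(model, fn):
      # Step 1: make the model use whatever parameters are passed in
      def inner_fn(params, *args, **kwargs):
          model.update(params)
          return fn(*args, **kwargs)

      # Step 2: differentiate with respect to the first argument (params)
      value_grad_fn = mx.value_and_grad(inner_fn)

      # Step 3: supply the trainable parameters as that first argument
      def wrapped_fn(*args, **kwargs):
          return value_grad_fn(model.trainable_parameters(), *args, **kwargs)

      return wrapped_fn

  model = nn.Linear(3, 1)
  x = mx.ones((5, 3))
  y = mx.zeros((5, 1))

  def loss_fn(model, x, y):
      return ((model(x) - y) ** 2).mean()

  loss, grads = sketch_value_and_grad(model, loss_fn)(model, x, y)
  ref_loss, ref_grads = nn.value_and_grad(model, loss_fn)(model, x, y)

Both calls should produce the same loss and the same gradient tree, keyed like
``model.trainable_parameters()``.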
.. autosummary::
:toctree: _autosummary
value_and_grad
quantize
average_gradients
fsdp_apply_gradients
.. toctree::
nn/module
nn/layers
nn/functions
nn/losses
nn/init
nn/distributed
================================================
FILE: docs/src/python/ops.rst
================================================
.. _ops:
Operations
==========
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
abs
add
addmm
all
allclose
any
arange
arccos
arccosh
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
argpartition
argsort
array_equal
as_strided
atleast_1d
atleast_2d
atleast_3d
bitwise_and
bitwise_invert
bitwise_or
bitwise_xor
block_masked_mm
broadcast_arrays
broadcast_to
ceil
clip
concatenate
contiguous
conj
conjugate
convolve
conv1d
conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
conv_general
cos
cosh
cummax
cummin
cumprod
cumsum
degrees
dequantize
diag
diagonal
divide
divmod
einsum
einsum_path
equal
erf
erfinv
exp
expm1
expand_dims
eye
flatten
floor
floor_divide
full
gather_mm
gather_qmm
greater
greater_equal
hadamard_transform
identity
imag
inner
isfinite
isclose
isinf
isnan
isneginf
isposinf
issubdtype
kron
left_shift
less
less_equal
linspace
load
log
log2
log10
log1p
logaddexp
logcumsumexp
logical_not
logical_and
logical_or
logsumexp
matmul
max
maximum
mean
median
meshgrid
min
minimum
moveaxis
multiply
nan_to_num
negative
not_equal
ones
ones_like
outer
partition
pad
power
prod
put_along_axis
quantize
quantized_matmul
radians
real
reciprocal
remainder
repeat
reshape
right_shift
roll
round
rsqrt
save
savez
savez_compressed
save_gguf
save_safetensors
sigmoid
sign
sin
sinh
slice
slice_update
softmax
sort
split
sqrt
square
squeeze
stack
std
stop_gradient
subtract
sum
swapaxes
take
take_along_axis
tan
tanh
tensordot
tile
topk
trace
transpose
tri
tril
triu
unflatten
var
view
where
zeros
zeros_like
================================================
FILE: docs/src/python/optimizers/common_optimizers.rst
================================================
.. _common_optimizers:
Common Optimizers
=================
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion
MultiOptimizer
Muon
================================================
FILE: docs/src/python/optimizers/optimizer.rst
================================================
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update
================================================
FILE: docs/src/python/optimizers/schedulers.rst
================================================
.. _schedulers:
Schedulers
==========
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
cosine_decay
exponential_decay
join_schedules
linear_schedule
step_decay
================================================
FILE: docs/src/python/optimizers.rst
================================================
.. _optimizers:
.. currentmodule:: mlx.optimizers
Optimizers
==========
The optimizers in MLX can be used both with :mod:`mlx.nn` and with pure
:mod:`mlx.core` functions. A typical example involves calling
:meth:`Optimizer.update` to update a model's parameters based on the loss
gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
model's parameters and the **optimizer state**.

.. code-block:: python

  # Create a model
  model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
  mx.eval(model.parameters())

  # Create the gradient function and the optimizer
  loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
  optimizer = optim.SGD(learning_rate=learning_rate)

  for e in range(num_epochs):
      for X, y in batch_iterate(batch_size, train_images, train_labels):
          loss, grads = loss_and_grad_fn(model, X, y)

          # Update the model with the gradients. So far no computation has happened.
          optimizer.update(model, grads)

          # Compute the new parameters but also the optimizer state.
          mx.eval(model.parameters(), optimizer.state)

Saving and Loading
------------------
To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:
.. code-block:: python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim
optimizer = optim.Adam(learning_rate=1e-2)
# Perform some updates with the optimizer
model = {"w" : mx.zeros((5, 5))}
grads = {"w" : mx.ones((5, 5))}
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors"))
optimizer.state = state
Note, not every optimizer configuration parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is if the parameter can be scheduled
then it will be included in the optimizer state.
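
A quick way to check which configuration values travel with the state is to
inspect its keys. The sketch below assumes the state behaves like a dictionary
keyed by parameter name, as the example above suggests.

.. code-block:: python

  import mlx.optimizers as optim

  optimizer = optim.Adam(learning_rate=1e-2)

  # The schedulable learning rate lives in the optimizer state ...
  has_lr = "learning_rate" in optimizer.state

  # ... while betas and eps are not stored there
  has_betas = "betas" in optimizer.state
  has_eps = "eps" in optimizer.state

  print(has_lr, has_betas, has_eps)

In practice this means ``betas`` and ``eps`` have to be passed again to the
constructor when recreating the optimizer from a checkpoint.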
.. toctree::
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. autosummary::
:toctree: _autosummary
clip_grad_norm
================================================
FILE: docs/src/python/random.rst
================================================
.. _random:
Random
======
Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.
For example, you can generate random numbers with:

.. code-block:: python

  for _ in range(3):
      print(mx.random.uniform())

which will print a sequence of unique pseudo random numbers. Alternatively you
can explicitly set the key:

.. code-block:: python

  key = mx.random.key(0)
  for _ in range(3):
      print(mx.random.uniform(key=key))

which will yield the same pseudo random number at each iteration.
Following `JAX's PRNG design `_
we use a splittable version of Threefry, which is a counter-based PRNG.
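
To draw different numbers while keeping explicit control of the state, a key
can be split into sub-keys with :func:`mlx.core.random.split`. The following is
a minimal sketch; the seed and shapes are arbitrary.

.. code-block:: python

  import mlx.core as mx

  key = mx.random.key(0)

  # Reusing a key reproduces the same sample
  a = mx.random.uniform(shape=(4,), key=key)
  b = mx.random.uniform(shape=(4,), key=key)

  # Splitting yields two sub-keys, one per draw
  k1, k2 = mx.random.split(key)
  c = mx.random.uniform(shape=(4,), key=k1)
  d = mx.random.uniform(shape=(4,), key=k2)

  same_key_matches = mx.array_equal(a, b).item()
  sub_keys_match = mx.array_equal(c, d).item()
  print(same_key_matches, sub_keys_match)
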
.. currentmodule:: mlx.core.random
.. autosummary::
:toctree: _autosummary
bernoulli
categorical
gumbel
key
normal
multivariate_normal
randint
seed
split
truncated_normal
uniform
laplace
permutation
================================================
FILE: docs/src/python/transforms.rst
================================================
.. _transforms:
Transforms
==========
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
eval
async_eval
compile
checkpoint
custom_function
disable_compile
enable_compile
grad
value_and_grad
jvp
vjp
vmap
================================================
FILE: docs/src/python/tree_utils.rst
================================================
.. _utils:
Tree Utils
==========
In MLX we consider a python tree to be an arbitrarily nested collection of
dictionaries, lists and tuples without cycles. Functions in this module that
return python trees will be using the default python ``dict``, ``list`` and
``tuple`` but they can usually process objects that inherit from any of these.
.. note::
Dictionaries should have keys that are valid python identifiers.
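
As a small illustration of what counts as a tree, the sketch below flattens,
maps over and rebuilds a nested structure. Plain integers are used as leaves to
keep the example short; in practice the leaves are usually
:class:`mlx.core.array` instances.

.. code-block:: python

  from mlx.utils import tree_flatten, tree_map, tree_unflatten

  tree = {"layers": [{"w": 1, "b": 2}, {"w": 3, "b": 4}], "scale": 5}

  # Flatten to a list of (dotted.path, leaf) pairs
  flat = tree_flatten(tree)

  # Apply a function to every leaf, keeping the structure
  doubled = tree_map(lambda v: 2 * v, tree)

  # Rebuild the nested structure from the flat list
  rebuilt = tree_unflatten(flat)

  print(flat)
  print(doubled)
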
.. currentmodule:: mlx.utils
.. autosummary::
:toctree: _autosummary
tree_flatten
tree_unflatten
tree_map
tree_map_with_path
tree_reduce
================================================
FILE: docs/src/usage/compile.rst
================================================
.. _compile:
Compilation
===========
.. currentmodule:: mlx.core
MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.
Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.
Basics of Compile
-----------------
Let's start with a simple example:

.. code-block:: python

  def fun(x, y):
      return mx.exp(-x) + y

  x = mx.array(1.0)
  y = mx.array(2.0)

  # Regular call, no compilation
  # Prints: array(2.36788, dtype=float32)
  print(fun(x, y))

  # Compile the function
  compiled_fun = mx.compile(fun)

  # Prints: array(2.36788, dtype=float32)
  print(compiled_fun(x, y))

The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.

.. code-block:: python

  def fun(x, y):
      return mx.exp(-x) + y

  x = mx.array(1.0)
  y = mx.array(2.0)

  compiled_fun = mx.compile(fun)

  # Compiled here
  compiled_fun(x, y)

  # Not compiled again
  compiled_fun(x, y)

  # Not compiled again
  mx.compile(fun)(x, y)

There are some important cases to be aware of that can cause a function to
be recompiled:
* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function
In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.
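
One way to observe when a function is traced again is to record a Python side
effect inside it, since the Python body only runs while MLX traces the
function. The sketch below is for illustration only (compiled functions should
otherwise be pure), and the ``traces`` list is a hypothetical helper.

.. code-block:: python

  import mlx.core as mx

  traces = []

  def fun(x):
      # This append only runs when the function is traced
      traces.append(x.shape)
      return mx.exp(-x)

  compiled_fun = mx.compile(fun)

  compiled_fun(mx.array(1.0))         # first call: traced
  compiled_fun(mx.array(2.0))         # same shape and type: cached
  compiled_fun(mx.array([1.0, 2.0]))  # new shape: traced again

  print(len(traces))

With the caching behavior described above, the list should hold two entries:
one for the scalar input and one for the one-dimensional input.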
Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:

.. code-block:: python

  a = mx.array(1.0)
  # Don't do this, compiles lambda at each iteration
  for _ in range(5):
      mx.compile(lambda x: mx.exp(mx.abs(x)))(a)

Example Speedup
---------------
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:

.. code-block:: python

  import math

  def gelu(x):
      return x * (1 + mx.erf(x / math.sqrt(2))) / 2

If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in the ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speed up both cases considerably.
Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:

.. code-block:: python

  import time

  def timeit(fun, x):
      # warm up
      for _ in range(10):
          mx.eval(fun(x))

      tic = time.perf_counter()
      for _ in range(100):
          mx.eval(fun(x))
      toc = time.perf_counter()
      tpi = 1e3 * (toc - tic) / 100
      print(f"Time per iteration {tpi:.3f} (ms)")

Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x)
timeit(mx.compile(gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
Debugging
---------
When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.

.. code-block:: python

  @mx.compile
  def fun(x):
      z = -x
      print(z)  # Crash
      return mx.exp(z)

  fun(mx.array(5.0))

For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:

.. code-block:: python

  @mx.compile
  def fun(x):
      z = -x
      print(z)  # Okay
      return mx.exp(z)

  mx.disable_compile()
  fun(mx.array(5.0))

Pure Functions
--------------
Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:

.. code-block:: python

  state = []

  @mx.compile
  def fun(x, y):
      z = x + y
      state.append(z)
      return mx.exp(z)

  fun(mx.array(1.0), mx.array(2.0))
  # Crash!
  print(state)

After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.
You have two options to deal with this. The first option is to simply return
``state`` as an output:

.. code-block:: python

  state = []

  @mx.compile
  def fun(x, y):
      z = x + y
      state.append(z)
      return mx.exp(z), state

  _, state = fun(mx.array(1.0), mx.array(2.0))
  # Prints [array(3, dtype=float32)]
  print(state)

In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:

.. code-block:: python

  from functools import partial

  state = []

  # Tell compile to capture state as an output
  @partial(mx.compile, outputs=state)
  def fun(x, y):
      z = x + y
      state.append(z)
      return mx.exp(z)

  fun(mx.array(1.0), mx.array(2.0))
  # Prints [array(3, dtype=float32)]
  print(state)

This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.
Compiled functions will also treat any inputs not in the parameter list as
constants. For example:

.. code-block:: python

  state = [mx.array(1.0)]

  @mx.compile
  def fun(x):
      return x + state[0]

  # Prints array(2, dtype=float32)
  print(fun(mx.array(1.0)))

  # Update state
  state[0] = mx.array(5.0)

  # Still prints array(2, dtype=float32)
  print(fun(mx.array(1.0)))

In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function.

.. code-block:: python

  state = [mx.array(1.0)]

  @mx.compile
  def fun(x, state):
      return x + state[0]

  # Prints array(2, dtype=float32)
  print(fun(mx.array(1.0), state))

  # Update state
  state[0] = mx.array(5.0)

  # Prints array(6, dtype=float32)
  print(fun(mx.array(1.0), state))

In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:

.. code-block:: python

  from functools import partial

  state = [mx.array(1.0)]

  # Tell compile to capture state as an input
  @partial(mx.compile, inputs=state)
  def fun(x):
      return x + state[0]

  # Prints array(2, dtype=float32)
  print(fun(mx.array(1.0)))

  # Update state
  state[0] = mx.array(5.0)

  # Prints array(6, dtype=float32)
  print(fun(mx.array(1.0)))

Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:

.. code-block:: python

  import mlx.core as mx
  import mlx.nn as nn
  import mlx.optimizers as optim

  # 4 examples with 10 features each
  x = mx.random.uniform(shape=(4, 10))

  # 0, 1 targets
  y = mx.array([0, 1, 0, 1])

  # Simple linear model
  model = nn.Linear(10, 1)

  # SGD with momentum
  optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

  def loss_fn(model, x, y):
      logits = model(x).squeeze()
      return nn.losses.binary_cross_entropy(logits, y)

  loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

  # Perform 10 steps of gradient descent
  for it in range(10):
      loss, grads = loss_and_grad_fn(model, x, y)
      optimizer.update(model, grads)
      mx.eval(model.parameters(), optimizer.state)

To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:

.. code-block:: python

  import mlx.core as mx
  import mlx.nn as nn
  import mlx.optimizers as optim
  from functools import partial

  # 4 examples with 10 features each
  x = mx.random.uniform(shape=(4, 10))

  # 0, 1 targets
  y = mx.array([0, 1, 0, 1])

  # Simple linear model
  model = nn.Linear(10, 1)

  # SGD with momentum
  optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

  def loss_fn(model, x, y):
      logits = model(x).squeeze()
      return nn.losses.binary_cross_entropy(logits, y)

  # The state that will be captured as input and output
  state = [model.state, optimizer.state]

  @partial(mx.compile, inputs=state, outputs=state)
  def step(x, y):
      loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
      loss, grads = loss_and_grad_fn(model, x, y)
      optimizer.update(model, grads)
      return loss

  # Perform 10 steps of gradient descent
  for it in range(10):
      loss = step(x, y)
      # Evaluate the model and optimizer state
      mx.eval(state)
      print(loss)

.. note::
If you are using a module which performs random sampling such as
:func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
optimizer.state, mx.random.state]``.
.. note::
For more examples of compiling full training graphs check out the `MLX
Examples `_ GitHub repo.
Transformations with Compile
----------------------------
In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
`.
Compiling transformed functions works just as expected:
.. code-block:: python
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
.. note::
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
the most opportunity to optimize the computation graph:

.. code-block:: python

  @mx.compile
  def inner(x):
      return mx.exp(-mx.abs(x))

  def outer(x):
      return inner(inner(x))

  # Compiling the outer function is good to do as it will likely
  # be faster even though the inner functions are compiled
  fun = mx.compile(outer)

.. _shapeless_compile:
Shapeless Compilation
---------------------
When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.

.. code-block:: python

  def fun(x, y):
      return mx.abs(x + y)

  compiled_fun = mx.compile(fun, shapeless=True)

  x = mx.array(1.0)
  y = mx.array(-2.0)

  # First call compiles the function
  print(compiled_fun(x, y))

  # Second call with different shapes
  # does not recompile the function
  x = mx.array([1.0, -6.0])
  y = mx.array([-2.0, 3.0])
  print(compiled_fun(x, y))

Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:

.. code-block:: python

  def fun(x):
      return x.reshape(x.shape[0] * x.shape[1], -1)

  compiled_fun = mx.compile(fun, shapeless=True)

  x = mx.random.uniform(shape=(2, 3, 4))

  out = compiled_fun(x)

  x = mx.random.uniform(shape=(5, 5, 3))

  # Error, can't reshape (5, 5, 3) to (6, -1)
  out = compiled_fun(x)

The second call to the ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:

.. code-block:: python

  def fun(x):
      return x.flatten(0, 1)

  compiled_fun = mx.compile(fun, shapeless=True)

  x = mx.random.uniform(shape=(2, 3, 4))

  out = compiled_fun(x)

  x = mx.random.uniform(shape=(5, 5, 3))

  # Ok
  out = compiled_fun(x)

================================================
FILE: docs/src/usage/distributed.rst
================================================
.. _usage_distributed:
Distributed Communication
=========================
.. currentmodule:: mlx.core.distributed
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support several different communication backends introduced below.
.. list-table::
:widths: 20 80
:header-rows: 1
* - Backend
- Description
* - :ref:`MPI `
- A full featured and mature distributed communications library.
* - :ref:`RING `
- Ring all reduce and all gather over TCP sockets. Always available and
usually faster than MPI.
* - :ref:`JACCL `
- Low latency communication with RDMA over Thunderbolt. Necessary for
things like tensor parallelism.
* - :ref:`NCCL `
- The backend of choice for CUDA environments.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs`.
Getting Started
---------------
A distributed program in MLX is as simple as:
.. code:: python
import mlx.core as mx
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all
distributed processes. However, when this script is run with ``python`` only
one process is launched and no distributed communication takes place. Namely,
all operations in ``mx.distributed`` are noops when the distributed group has a
size of one. This property allows us to avoid code that checks if we are in a
distributed setting similar to the one below:

.. code:: python

  import mlx.core as mx

  x = ...
  world = mx.distributed.init()
  # No need for the check we can simply do x = mx.distributed.all_sum(x)
  if world.size() > 1:
      x = mx.distributed.all_sum(x)

Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
MLX provides ``mlx.launch``, a helper script to launch distributed programs.
Continuing with our initial example, we can run it on localhost with 4 processes using:
.. code:: shell
$ mlx.launch -n 4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh):
.. code:: shell
$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
Consult the dedicated :doc:`usage guide` for more
information on using ``mlx.launch``.
Selecting Backend
^^^^^^^^^^^^^^^^^
You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
available backends. If they all fail then a singleton group is created.
.. note::
After a distributed backend is successfully initialized :func:`init` will
return **the same backend** if called without arguments or with backend set to
``any``.
The following examples aim to clarify the backend initialization logic in MLX:
.. code:: python
# Case 1: Initialize MPI regardless of whether the ring backend could be initialized
world = mx.distributed.init(backend="mpi")
world2 = mx.distributed.init() # subsequent calls return the MPI backend!
# Case 2: Initialize any backend
world = mx.distributed.init(backend="any") # equivalent to no arguments
world2 = mx.distributed.init() # same as above
# Case 3: Initialize both backends at the same time
world_mpi = mx.distributed.init(backend="mpi")
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init() # same as MPI because it was initialized first!
Distributed Program Examples
----------------------------
- :ref:`Data Parallelism `
- :ref:`Tensor Parallelism `
.. _ring_section:
Getting Started with Ring
-------------------------
The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
and so on and so forth. As a result :func:`send` and :func:`recv` with
arbitrary sender and receiver are not supported in the ring backend.
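
The neighbor relationship can be written down in a few lines of plain Python.
This is only an illustration of the ring topology described above;
``ring_neighbors`` is a hypothetical helper and not part of MLX.

.. code-block:: python

  def ring_neighbors(rank, size):
      """Return the (previous, next) ranks of `rank` in a ring of `size` nodes."""
      return (rank - 1) % size, (rank + 1) % size

  for rank in range(4):
      prev_rank, next_rank = ring_neighbors(rank, 4)
      print(f"rank {rank}: neighbors {prev_rank} and {next_rank}")

The modulo closes the ring, so in a 4 node ring rank 0 is also a neighbor of
rank 3.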
Defining a Ring
^^^^^^^^^^^^^^^
The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script `. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.

.. code:: json

  [
      {"ssh": "hostname1", "ips": ["123.123.123.1"]},
      {"ssh": "hostname2", "ips": ["123.123.123.2"]},
      {"ssh": "hostname3", "ips": ["123.123.123.3"]},
      {"ssh": "hostname4", "ips": ["123.123.123.4"]}
  ]

Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node, run the script which will listen for connections in each of the provided
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4`` and so on and so forth.
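
For larger rings it can be convenient to generate the hostfile instead of
writing it by hand. The sketch below builds the same 4 node hostfile with the
standard library and derives, for each rank, the IP it connects to and the IP
it accepts a connection from, following the description above. The hostnames
and IPs are the placeholder values from this section.

.. code-block:: python

  import json

  hosts = [f"hostname{i}" for i in range(1, 5)]
  ips = [f"123.123.123.{i}" for i in range(1, 5)]

  # One entry per node: where to ssh and which IPs the node listens on
  hostfile = [{"ssh": h, "ips": [ip]} for h, ip in zip(hosts, ips)]
  hostfile_text = json.dumps(hostfile, indent=4)
  print(hostfile_text)

  # Rank i connects to rank i + 1 and accepts from rank i - 1 (wrapping around)
  n = len(hosts)
  plan = [(ips[(r + 1) % n], ips[(r - 1) % n]) for r in range(n)]
  for host, (connect_to, accept_from) in zip(hosts, plan):
      print(f"{host}: connects to {connect_to}, accepts from {accept_from}")

Saving ``hostfile_text`` as ``ring-4.json`` gives a file that can be passed to
``mlx.launch --hostfile``.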
Thunderbolt Ring
^^^^^^^^^^^^^^^^
Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such Thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via Thunderbolt cables and then call the
utility as follows:
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring
By default the script will attempt to discover the Thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
If you want to go through the process manually, the steps are as follows:
* Disable the Thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.
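
The addressing part of the manual steps can be planned ahead of time. The
sketch below assigns a separate ``/30`` subnetwork to each cable using Python's
``ipaddress`` module. The ``192.168.0.0/16`` base range and the ``/30`` size
are assumptions made for this example; ``link_subnets`` is a hypothetical
helper, and the actual commands to apply the addresses are the ones prepared by
the utility script.

.. code-block:: python

  import ipaddress

  def link_subnets(num_nodes, base="192.168.0.0/16"):
      """Assign one /30 subnet to the cable from rank i to rank (i + 1) % num_nodes."""
      subnets = ipaddress.ip_network(base).subnets(new_prefix=30)
      plan = []
      for i, subnet in zip(range(num_nodes), subnets):
          # A /30 network has exactly two usable host addresses
          first, second = list(subnet.hosts())
          plan.append((i, (i + 1) % num_nodes, str(first), str(second)))
      return plan

  for src, dst, src_ip, dst_ip in link_subnets(4):
      print(f"cable {src} -> {dst}: {src_ip} on rank {src}, {dst_ip} on rank {dst}")

Each cable gets its own pair of addresses, so no two links share a subnetwork.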
.. _jaccl_section:
Getting Started with JACCL
--------------------------
Starting from macOS 26.2, RDMA over Thunderbolt is available and
enables low-latency communication between Macs with Thunderbolt 5. MLX provides
the JACCL backend that uses this functionality to achieve communication latency
an order of magnitude lower than the ring backend.
.. note::
The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective
Communication Library* and it is an obvious pun on NVIDIA's NCCL but also a
tribute to *Jack Beasley* who led the development of RDMA over Thunderbolt
at Apple.
Enabling RDMA
^^^^^^^^^^^^^
Until the feature matures, enabling RDMA over Thunderbolt is slightly more
involved and **cannot** be done remotely even with sudo. In fact, it has to be
done in macOS recovery:
1. `Start your computer in recovery `_.
2. Open the Terminal by going to Utilities -> Terminal.
3. Run ``rdma_ctl enable``.
4. Reboot.
To verify that you have successfully enabled Thunderbolt RDMA you can run
``ibv_devices`` which should produce something like the following for an M3 Ultra.
.. code-block:: bash
~ % ibv_devices
device node GUID
------ ----------------
rdma_en2 8096a9d9edbaac05
rdma_en3 8196a9d9edbaac05
rdma_en5 8396a9d9edbaac05
rdma_en4 8296a9d9edbaac05
rdma_en6 8496a9d9edbaac05
rdma_en7 8596a9d9edbaac05
Defining a Mesh
^^^^^^^^^^^^^^^
The JACCL backend supports only fully connected topologies. Namely, there needs
to be a Thunderbolt cable connecting all pairs of Macs directly. For example, in
the following topology visualizations, the left one is valid because there is a
connection from any node to any other node, while for the one on the right M3
Ultra 1 is not connected to M3 Ultra 2.
.. raw:: html
Fully connected mesh of four M3 Ultra.
Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).
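
The fully connected requirement can be checked programmatically before
launching anything. The following plain Python sketch takes a list of node
names and a list of cables and verifies that every pair of nodes is directly
connected. ``is_fully_connected`` and the node names are hypothetical and not
part of MLX.

.. code-block:: python

  from itertools import combinations

  def is_fully_connected(nodes, cables):
      """Return True if every pair of nodes is joined directly by a cable."""
      links = {frozenset(cable) for cable in cables}
      return all(frozenset(pair) in links for pair in combinations(nodes, 2))

  nodes = ["ultra1", "ultra2", "ultra3", "ultra4"]

  # Valid mesh: one cable for each of the 6 pairs
  full_mesh = list(combinations(nodes, 2))

  # Invalid mesh: the cable between ultra1 and ultra2 is missing
  partial_mesh = [c for c in full_mesh if set(c) != {"ultra1", "ultra2"}]

  print(is_fully_connected(nodes, full_mesh))
  print(is_fully_connected(nodes, partial_mesh))

A mesh of ``n`` Macs needs ``n * (n - 1) / 2`` cables, which is 6 for four
machines.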