SYMBOL INDEX (4174 symbols across 517 files) FILE: benchmarks/cpp/autograd.cpp function time_value_and_grad (line 10) | void time_value_and_grad() { function main (line 36) | int main() { FILE: benchmarks/cpp/compare_devices.cpp function time_add_op (line 9) | void time_add_op() { function main (line 25) | int main() { FILE: benchmarks/cpp/irregular_strides.cpp function time_irregular_binary_ops_1D (line 12) | void time_irregular_binary_ops_1D() { function time_irregular_binary_ops_2D (line 24) | void time_irregular_binary_ops_2D() { function time_irregular_binary_ops_3D (line 45) | void time_irregular_binary_ops_3D() { function time_irregular_binary_ops_4D (line 76) | void time_irregular_binary_ops_4D() { function time_irregular_reshape (line 116) | void time_irregular_reshape() { function time_irregular_astype_1D (line 161) | void time_irregular_astype_1D() { function time_irregular_astype_2D (line 170) | void time_irregular_astype_2D() { function main (line 188) | int main(int argc, char** argv) { FILE: benchmarks/cpp/single_ops.cpp function time_creation_ops (line 8) | void time_creation_ops() { function time_type_conversions (line 23) | void time_type_conversions() { function time_random_generation (line 45) | void time_random_generation() { function time_unary_ops (line 55) | void time_unary_ops() { function time_binary_ops (line 74) | void time_binary_ops() { function time_strided_ops (line 112) | void time_strided_ops() { function time_comparisons (line 125) | void time_comparisons() { function time_matvec (line 138) | void time_matvec() { function time_matmul (line 151) | void time_matmul() { function time_reductions (line 163) | void time_reductions() { function time_gather_scatter (line 213) | void time_gather_scatter() { function time_divmod (line 260) | void time_divmod() { function main (line 274) | int main() { FILE: benchmarks/numpy/single_ops.py function time_add (line 7) | def time_add(): function time_matmul (line 13) | def time_matmul(): function time_exp (line 19) | def time_exp(): function time_take (line 24) | def time_take(): FILE: benchmarks/numpy/time_utils.py function time_fn (line 6) | def time_fn(fn, *args): FILE: benchmarks/python/batch_matmul_bench.py function time_batch_matmul (line 13) | def time_batch_matmul(): function time_unbatch_matmul (line 33) | def time_unbatch_matmul(): FILE: benchmarks/python/blas/bench_gemm.py function bench (line 21) | def bench(f, a, b): function gemm_nn_mlx (line 33) | def gemm_nn_mlx(a, b): function gemm_nt_mlx (line 42) | def gemm_nt_mlx(a, b): function gemm_tn_mlx (line 51) | def gemm_tn_mlx(a, b): function gemm_tt_mlx (line 60) | def gemm_tt_mlx(a, b): function gemm_nn_torch (line 70) | def gemm_nn_torch(a, b): function gemm_nt_torch (line 80) | def gemm_nt_torch(a, b): function gemm_tn_torch (line 90) | def gemm_tn_torch(a, b): function gemm_tt_torch (line 100) | def gemm_tt_torch(a, b): function bench_shape (line 109) | def bench_shape(B, M, N, K, np_dtype, transpose="nn"): function get_gflop_count (line 157) | def get_gflop_count(B, M, N, K): FILE: benchmarks/python/blas/bench_gemv.py function bench (line 36) | def bench(f, m, v): function gemv_mlx (line 48) | def gemv_mlx(m, v): function gemv_t_mlx (line 57) | def gemv_t_mlx(m, v): function gemv_torch (line 67) | def gemv_torch(m, v): function gemv_t_torch (line 77) | def gemv_t_torch(m, v): function bench_lens (line 86) | def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False): function get_gflop_count (line 123) | def get_gflop_count(in_vec_len, out_vec_len): function get_gbyte_size (line 129) | def get_gbyte_size(in_vec_len, out_vec_len, np_dtype): function bench_with_in_len (line 135) | def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose): function bench_with_out_len (line 166) | def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose): FILE: benchmarks/python/comparative/bench_mlx.py function int_or_list (line 13) | def int_or_list(x): function none_or_list (line 20) | def none_or_list(x): function dtype_from_str (line 27) | def dtype_from_str(x): function bench (line 37) | def bench(f, *args): function matmul_square (line 48) | def matmul_square(x): function matmul (line 56) | def matmul(x, y): function _quant_matmul (line 63) | def _quant_matmul(x, w, s, b, transpose, group_size, bits): function conv1d (line 120) | def conv1d(x, y): function conv2d (line 127) | def conv2d(x, y): function binary (line 134) | def binary(op, x, y): function reduction (line 140) | def reduction(op, axis, x): function sum_and_add (line 147) | def sum_and_add(axis, x, y): function softmax (line 154) | def softmax(axis, x): function softmax_fused (line 163) | def softmax_fused(axis, x): function relu (line 171) | def relu(x): function leaky_relu (line 178) | def leaky_relu(x: mx.array): function prelu (line 185) | def prelu(x: mx.array): function softplus (line 192) | def softplus(x: mx.array): function mish (line 199) | def mish(x: mx.array): function leaky_relu (line 206) | def leaky_relu(x): function elu (line 213) | def elu(x): function relu6 (line 220) | def relu6(x): function softplus (line 227) | def softplus(x): function celu (line 234) | def celu(x): function log_sigmoid (line 241) | def log_sigmoid(x): function scalar_mult (line 248) | def scalar_mult(x): function cross_entropy (line 255) | def cross_entropy(targets, x): function logsumexp (line 265) | def logsumexp(axis, x): function linear (line 272) | def linear(w, b, x): function linear_fused (line 279) | def linear_fused(w, b, x): function rope (line 286) | def rope(x): function concatenate (line 307) | def concatenate(axis, x, y): function cumsum (line 314) | def cumsum(axis, x): function sort (line 321) | def sort(axis, x): function topk (line 328) | def topk(axis, x): function step_function (line 336) | def step_function(x): function selu (line 343) | def selu(x): FILE: benchmarks/python/comparative/bench_torch.py function int_or_list (line 12) | def int_or_list(x): function none_or_list (line 19) | def none_or_list(x): function dtype_from_str (line 26) | def dtype_from_str(x): function bench (line 36) | def bench(f, *args): function sync_if_needed (line 47) | def sync_if_needed(x): function matmul_square (line 55) | def matmul_square(x): function matmul (line 63) | def matmul(x, y): function conv1d (line 71) | def conv1d(x, y): function conv2d (line 81) | def conv2d(x, y): function binary (line 91) | def binary(op, x, y): function reduction (line 98) | def reduction(op, axis, x): function sum_and_add (line 106) | def sum_and_add(axis, x, y): function softmax (line 114) | def softmax(axis, x): function softmax_fused (line 124) | def softmax_fused(axis, x): function relu (line 132) | def relu(x): function leaky_relu (line 140) | def leaky_relu(x): function elu (line 148) | def elu(x): function celu (line 156) | def celu(x): function relu6 (line 164) | def relu6(x): function softplus (line 172) | def softplus(x): function log_sigmoid (line 180) | def log_sigmoid(x): function prelu (line 188) | def prelu(x: torch.Tensor) -> torch.Tensor: function mish (line 196) | def mish(x: torch.Tensor) -> torch.Tensor: function scalar_mult (line 204) | def scalar_mult(x): function cross_entropy (line 212) | def cross_entropy(targets, x): function logsumexp (line 220) | def logsumexp(axis, x): function linear_fused (line 228) | def linear_fused(w, b, x): function linear (line 236) | def linear(w, b, x): function rope (line 244) | def rope(x): function concatenate (line 265) | def concatenate(axis, x, y): function cumsum (line 273) | def cumsum(axis, x): function sort (line 281) | def sort(axis, x): function topk (line 289) | def topk(axis, x): function step_function (line 298) | def step_function(x): function selu (line 306) | def selu(x): FILE: benchmarks/python/comparative/compare.py function run_or_raise (line 14) | def run_or_raise(*args, **kwargs): function compare (line 24) | def compare(args): function compare_mlx_dtypes (line 31) | def compare_mlx_dtypes(args, dt1, dt2): function make_regex_search (line 38) | def make_regex_search(regexes): function make_predicate (line 47) | def make_predicate(positive_filter, negative_filter): FILE: benchmarks/python/compile_bench.py function bench_gelu (line 11) | def bench_gelu(): function bench_layernorm (line 52) | def bench_layernorm(): FILE: benchmarks/python/conv1d_bench.py function bench (line 19) | def bench(f, a, b): function make_mx_conv_1D (line 31) | def make_mx_conv_1D(strides=1, padding=0, groups=1): function make_pt_conv_1D (line 43) | def make_pt_conv_1D(strides=1, padding=0, groups=1): function bench_shape (line 56) | def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups): FILE: benchmarks/python/conv2d_bench_cpu.py function bench (line 15) | def bench(f, a, b): function make_mx_conv_2D (line 26) | def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): function make_pt_conv_2D (line 38) | def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): function bench_shape (line 50) | def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): FILE: benchmarks/python/conv2d_train_bench_cpu.py function bench_mlx (line 9) | def bench_mlx(steps: int = 20) -> float: function bench_torch (line 73) | def bench_torch(steps: int = 20) -> float: function main (line 128) | def main(): FILE: benchmarks/python/conv2d_transpose_bench_cpu.py function bench (line 14) | def bench(f, a, b): function make_mx_conv_transpose_2D (line 25) | def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): function make_pt_conv_transpose_2D (line 39) | def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): function bench_shape (line 53) | def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): FILE: benchmarks/python/conv3d_bench.py function bench (line 13) | def bench(f, a, b, b_prime): function make_mx_conv_3D (line 25) | def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): function make_pt_conv_3D (line 37) | def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): function bench_shape (line 50) | def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, ... FILE: benchmarks/python/conv3d_bench_cpu.py function bench (line 15) | def bench(f, a, b): function make_mx_conv_3D (line 26) | def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1): function make_pt_conv_3D (line 38) | def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1): function bench_shape (line 50) | def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, ... FILE: benchmarks/python/conv3d_train_bench_cpu.py function bench_mlx (line 9) | def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float: function bench_torch (line 73) | def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float: function main (line 128) | def main(): FILE: benchmarks/python/conv3d_transpose_bench_cpu.py function bench (line 15) | def bench(f, a, b): function make_mx_conv_3D (line 26) | def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): function make_pt_conv_3D (line 40) | def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): function bench_shape (line 54) | def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, ... FILE: benchmarks/python/conv_bench.py function bench (line 19) | def bench(f, a, b): function make_mx_conv_2D (line 31) | def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): function make_pt_conv_2D (line 43) | def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): function bench_shape (line 56) | def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): FILE: benchmarks/python/conv_transpose_bench.py function bench (line 16) | def bench(f, a, b): function make_mx_conv_transpose_2D (line 28) | def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): function make_pt_conv_transpose_2D (line 42) | def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): function bench_shape (line 57) | def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): FILE: benchmarks/python/conv_unaligned_bench.py function bench (line 13) | def bench(f, a, b): function make_mx_conv_2D (line 25) | def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): function make_pt_conv_2D (line 37) | def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): function bench_shape (line 50) | def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): FILE: benchmarks/python/distributed_bench.py function time_fn (line 13) | def time_fn(fn, *args, **kwargs): function time_all_sum (line 37) | def time_all_sum(): FILE: benchmarks/python/einsum_bench.py function timeit (line 9) | def timeit(fn, its=100, args=[]): function time_little_einsum_path (line 19) | def time_little_einsum_path(): function time_big_einsum_path (line 33) | def time_big_einsum_path(): function time_attention (line 55) | def time_attention(): FILE: benchmarks/python/fft_bench.py function bandwidth_gb (line 14) | def bandwidth_gb(runtime_ms, system_size): function run_bench (line 21) | def run_bench(system_size, fft_sizes, backend="mlx", dim=1): function time_fft (line 62) | def time_fft(): FILE: benchmarks/python/gather_bench.py function benchmark_gather_mlx (line 10) | def benchmark_gather_mlx(x_shape, idx_shape): function benchmark_gather_torch (line 21) | def benchmark_gather_torch(x_shape, idx_shape, device): FILE: benchmarks/python/gather_mm_bench.py function gather_sort (line 13) | def gather_sort(x, indices): function scatter_unsort (line 21) | def scatter_unsort(x, inv_order, shape=None): function gather_mm_simulate (line 28) | def gather_mm_simulate(x, w, indices): function time_gather_mm (line 37) | def time_gather_mm(): FILE: benchmarks/python/gather_qmm_bench.py function gather_sort (line 13) | def gather_sort(x, indices): function scatter_unsort (line 21) | def scatter_unsort(x, inv_order, shape=None): function gather_mm_simulate (line 28) | def gather_mm_simulate(x, w, indices): function time_gather_qmm (line 43) | def time_gather_qmm(): FILE: benchmarks/python/hadamard_bench.py function had (line 12) | def had(x): function copy (line 17) | def copy(x): function run (line 22) | def run(dtype): FILE: benchmarks/python/large_gemm_bench.py function bench_mlx (line 14) | def bench_mlx(a, b): function bench_torch (line 29) | def bench_torch(a, b): function check_correctness (line 45) | def check_correctness(out_mx, out_pt, rtol, M, N, K): function bench_gemm (line 56) | def bench_gemm(M, N, K, dtype, rtol): FILE: benchmarks/python/layer_norm_bench.py function layer_norm (line 10) | def layer_norm(x, w, b, eps): function time_layer_norm (line 23) | def time_layer_norm(N, dt): FILE: benchmarks/python/masked_scatter.py function get_device_name (line 28) | def get_device_name(): function _power_of_two_formatter (line 62) | def _power_of_two_formatter(value, _position): function torch_sync (line 71) | def torch_sync(): function masked_scatter_mlx (line 78) | def masked_scatter_mlx(self_arr, mask_arr, src_arr): function masked_scatter_torch (line 89) | def masked_scatter_torch(self_tensor, mask_tensor, src_tensor): function measure (line 99) | def measure(fn): function bytes_touched (line 109) | def bytes_touched(length, true_count, item_size): function build_case (line 116) | def build_case(length, density, np_dtype, torch_dtype): function bench_case (line 148) | def bench_case(length, density, dtype): function plot_density (line 174) | def plot_density(ax_perf, ax_speedup, density, dtype): function main (line 208) | def main(): FILE: benchmarks/python/rms_norm_bench.py function rms_norm (line 8) | def rms_norm(x, w, eps): function time_rms_norm (line 18) | def time_rms_norm(): FILE: benchmarks/python/rope_bench.py function time_rope (line 8) | def time_rope(): FILE: benchmarks/python/scatter_bench.py function benchmark_scatter_mlx (line 10) | def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes): function benchmark_scatter_torch (line 25) | def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device): FILE: benchmarks/python/sdpa_bench.py function bench (line 20) | def bench(f, *args): function prepare_inputs (line 31) | def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype): function mlx_ref_attn (line 58) | def mlx_ref_attn(q, k, v, scale=1.0, mask=None): function mlx_fused_attn (line 104) | def mlx_fused_attn(q, k, v, scale, mask): function do_attention (line 108) | def do_attention(f, q, k, v, scale, mask=None, transpose=False): function do_attention_bench (line 119) | def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False): function bench_shape (line 129) | def bench_shape( function get_gflop_count (line 158) | def get_gflop_count(B, M, N, K): FILE: benchmarks/python/sdpa_vector_bench.py function upproject (line 16) | def upproject(x, w): function attention (line 23) | def attention(q, k, v, mask=None, w=None): function sdpa (line 45) | def sdpa(q, k, v, mask=None, w=None): function time_self_attention_primitives (line 52) | def time_self_attention_primitives(): function time_self_attention_sdpa (line 62) | def time_self_attention_sdpa(): function time_self_attention_sdpa_with_mask (line 72) | def time_self_attention_sdpa_with_mask(): FILE: benchmarks/python/segmented_mm_bench.py function parse_cases (line 16) | def parse_cases(cases): function make_segments (line 24) | def make_segments(k, num_segments, pattern, seed): function numpy_segmented_mm_ref (line 35) | def numpy_segmented_mm_ref(a, b, segments): function mlx_segmented_mm_loop (line 43) | def mlx_segmented_mm_loop(a, b, segments): function bench_mlx (line 52) | def bench_mlx(a, b, segments, warmup, iters): function bench_mlx_loop (line 67) | def bench_mlx_loop(a, b, segments, warmup, iters): function print_table (line 82) | def print_table(headers, rows): function main (line 102) | def main(): FILE: benchmarks/python/single_ops.py function time_add (line 9) | def time_add(): function time_matmul (line 40) | def time_matmul(): function time_maximum (line 47) | def time_maximum(): function time_max (line 54) | def time_max(): function time_min (line 61) | def time_min(): function time_negative (line 68) | def time_negative(): function time_exp (line 80) | def time_exp(): function time_logsumexp (line 86) | def time_logsumexp(): function time_take (line 92) | def time_take(): function time_reshape_transposed (line 104) | def time_reshape_transposed(): FILE: benchmarks/python/slice_update_bench.py function benchmark_slice_update_mlx (line 10) | def benchmark_slice_update_mlx(dst_shape, slice_shape, slice_range, dtyp... function benchmark_slice_update_torch (line 32) | def benchmark_slice_update_torch( FILE: benchmarks/python/synchronize_bench.py function timeit (line 8) | def timeit(fn, a): function all_reduce_benchmark (line 23) | def all_reduce_benchmark(): function all_gather_benchmark (line 39) | def all_gather_benchmark(): FILE: benchmarks/python/time_utils.py function time_fn (line 8) | def time_fn(fn, *args, **kwargs): function measure_runtime (line 29) | def measure_runtime(fn, **kwargs): FILE: docs/src/conf.py function setup (line 71) | def setup(app): FILE: examples/cmake_project/example.cpp function main (line 9) | int main() { FILE: examples/cpp/distributed.cpp function main (line 9) | int main() { FILE: examples/cpp/linear_regression.cpp function main (line 15) | int main() { FILE: examples/cpp/logistic_regression.cpp function main (line 15) | int main() { FILE: examples/cpp/metal_capture.cpp function main (line 10) | int main() { FILE: examples/cpp/timer.h function namespace (line 7) | namespace timer { FILE: examples/cpp/tutorial.cpp function array_basics (line 10) | void array_basics() { function automatic_differentiation (line 81) | void automatic_differentiation() { function main (line 96) | int main() { FILE: examples/export/eval_mlp.cpp function main (line 8) | int main() { FILE: examples/export/eval_mlp.py class MLP (line 8) | class MLP(nn.Module): method __init__ (line 11) | def __init__( method __call__ (line 21) | def __call__(self, x): function forward (line 39) | def forward(x): FILE: examples/export/train_mlp.cpp function main (line 8) | int main() { FILE: examples/export/train_mlp.py class MLP (line 9) | class MLP(nn.Module): method __init__ (line 12) | def __init__( method __call__ (line 22) | def __call__(self, x): function init (line 34) | def init(): function loss_fn (line 51) | def loss_fn(params, X, y): function step (line 55) | def step(*inputs): FILE: examples/extensions/axpby/axpby.cpp type my_ext (line 18) | namespace my_ext { function current_binary_dir (line 22) | std::string current_binary_dir() { function axpby (line 44) | mx::array axpby( function axpby_impl (line 82) | void axpby_impl( FILE: examples/extensions/axpby/axpby.h function namespace (line 10) | namespace my_ext { FILE: examples/extensions/bindings.cpp function NB_MODULE (line 11) | NB_MODULE(_ext, m) { FILE: examples/python/linear_regression.py function loss_fn (line 26) | def loss_fn(w): FILE: examples/python/logistic_regression.py function loss_fn (line 26) | def loss_fn(w): FILE: examples/python/qqmm.py function ulp_bf16_at (line 14) | def ulp_bf16_at(x): function test_qqmm (line 22) | def test_qqmm(): function test_qqmm_vjp (line 80) | def test_qqmm_vjp(): FILE: mlx/3rdparty/pocketfft.h function namespace (line 91) | namespace pocketfft { function cmplx (line 302) | static cmplx calc(size_t x, size_t n, Thigh ang) type util (line 369) | struct util // hack to avoid duplicate symbols function POCKETFFT_NOINLINE (line 383) | static POCKETFFT_NOINLINE double cost_guess (size_t n) function POCKETFFT_NOINLINE (line 401) | static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n) function POCKETFFT_NOINLINE (line 430) | static POCKETFFT_NOINLINE size_t good_size_real(size_t n) function prod (line 456) | static size_t prod(const shape_t &shape) function shutdown (line 753) | void shutdown() function restart (line 759) | void restart() function thread_pool (line 766) | inline thread_pool & get_pool() type fctdata (line 829) | struct fctdata function add_factor (line 839) | void add_factor(size_t factor) function twsize (line 1498) | size_t twsize() const function comp_twiddle (line 1512) | void comp_twiddle() type fctdata (line 1555) | struct fctdata function add_factor (line 1565) | void add_factor(size_t factor) function twsize (line 2297) | size_t twsize() const function comp_twiddle (line 2310) | void comp_twiddle() function exec (line 2527) | void exec(T c[], T0 fct, bool fwd) const function exec (line 2547) | void exec(T c[], T0 fct, bool ortho, function exec (line 2578) | void exec(T c[], T0 fct, function exec (line 2609) | void exec(T c[], T0 fct, bool ortho, function exec (line 2687) | void exec(T c[], T0 fct, function find_in_cache (line 2780) | auto find_in_cache = [&]() -> std::shared_ptr function class (line 2822) | class arr_info function advance_i (line 2870) | void advance_i() function advance (line 2916) | void advance(size_t n) function iofs (line 2927) | ptrdiff_t iofs(size_t i) const { return p_i[0] + ptrdiff_t(i)*str_i; } function iofs (line 2928) | ptrdiff_t iofs(size_t j, size_t i) const { return p_i[j] + ptrdiff_t(i)*... function oofs (line 2929) | ptrdiff_t oofs(size_t i) const { return p_o[0] + ptrdiff_t(i)*str_o; } function oofs (line 2930) | ptrdiff_t oofs(size_t j, size_t i) const { return p_o[j] + ptrdiff_t(i)*... function class (line 2938) | class simple_iter function class (line 2966) | class rev_iter type ExecC2C (line 3166) | struct ExecC2C type ExecHartley (line 3211) | struct ExecHartley type ExecDcst (line 3223) | struct ExecDcst type ExecR2R (line 3366) | struct ExecR2R function ExecDcst (line 3407) | const ExecDcst exec{ortho, type, true}; function ExecDcst (line 3425) | const ExecDcst exec{ortho, type, false}; function newaxes (line 3461) | auto newaxes = shape_t{axes.begin(), --axes.end()}; function newaxes (line 3499) | auto newaxes = shape_t{axes.begin(), --axes.end()}; FILE: mlx/allocator.h function namespace (line 9) | namespace mlx::core::allocator { FILE: mlx/array.cpp type mlx::core (line 11) | namespace mlx::core { function array (line 60) | array array::unsafe_weak_copy(const array& other) { FILE: mlx/array.h function namespace (line 16) | namespace mlx::core { function ArrayIterator (line 157) | struct MLX_API ArrayIterator { function Data (line 231) | struct Data { type Flags (line 248) | struct Flags { function set_siblings (line 313) | void set_siblings(std::vector siblings, uint16_t position) { function buffer_size (line 359) | size_t buffer_size() const { type Status (line 387) | enum Status { function is_available (line 404) | bool is_available() const; function set_status (line 414) | void set_status(Status s) const { function attach_event (line 424) | void attach_event(Event e) const { function set_tracer (line 433) | void set_tracer(bool is_tracer) { function is_tracer (line 437) | bool is_tracer() const; type MLX_API (line 468) | struct MLX_API function offset (line 489) | int64_t offset{0}; FILE: mlx/backend/common/binary.h function namespace (line 9) | namespace mlx::core { FILE: mlx/backend/common/broadcasting.cpp type mlx::core (line 5) | namespace mlx::core { function broadcast (line 7) | void broadcast(const array& in, array& out) { FILE: mlx/backend/common/broadcasting.h function namespace (line 7) | namespace mlx::core { FILE: mlx/backend/common/buffer_cache.h function T (line 30) | T* reuse_from_cache(size_t size) { function recycle_to_cache (line 48) | void recycle_to_cache(T* buf) { function release_cached_buffers (line 58) | int release_cached_buffers(size_t min_bytes_to_free) { function clear (line 87) | int clear() { function BufferHolder (line 150) | BufferHolder* tail_{nullptr}; FILE: mlx/backend/common/common.cpp type mlx::core (line 8) | namespace mlx::core { function prepare_reshape (line 147) | std::pair prepare_reshape(const array& in, const array&... function shared_buffer_reshape (line 184) | void shared_buffer_reshape( FILE: mlx/backend/common/compiled.cpp type mlx::core (line 7) | namespace mlx::core { function print_constant (line 9) | void print_constant(std::ostream& os, const array& x) { function get_type_string (line 47) | std::string get_type_string(Dtype d) { function compiled_check_contiguity (line 85) | bool compiled_check_contiguity( function compiled_allocate_outputs (line 113) | void compiled_allocate_outputs( function compiled_collapse_contiguous_dims (line 173) | std::tuple> compiled_collapse_contig... function compiled_use_large_index (line 224) | bool compiled_use_large_index( FILE: mlx/backend/common/compiled.h function namespace (line 10) | namespace mlx::core { function is_scalar (line 47) | inline bool is_scalar(const array& x) { FILE: mlx/backend/common/copy.h function namespace (line 7) | namespace mlx::core { FILE: mlx/backend/common/hadamard.h function namespace (line 9) | namespace mlx::core { FILE: mlx/backend/common/load.cpp function swap_endianness (line 12) | void swap_endianness(uint8_t* data_bytes, size_t N) { type mlx::core (line 28) | namespace mlx::core { FILE: mlx/backend/common/matmul.h function namespace (line 10) | namespace mlx::core { FILE: mlx/backend/common/quantized.h function namespace (line 3) | namespace mlx::core { FILE: mlx/backend/common/reduce.cpp type mlx::core (line 5) | namespace mlx::core { function shapes_without_reduction_axes (line 7) | std::pair shapes_without_reduction_axes( function shapes_without_reduction_axes (line 20) | std::pair shapes_without_reduction_axes( function ReductionPlan (line 29) | ReductionPlan get_reduction_plan(const array& x, const std::vector prepare_slice( function shared_buffer_slice (line 20) | void shared_buffer_slice( function slice (line 38) | void slice( FILE: mlx/backend/common/slicing.h function namespace (line 7) | namespace mlx::core { FILE: mlx/backend/common/ternary.h function namespace (line 8) | namespace mlx::core { FILE: mlx/backend/common/unary.h function namespace (line 8) | namespace mlx::core { FILE: mlx/backend/common/utils.cpp type mlx::core (line 7) | namespace mlx::core { function current_binary_dir (line 9) | std::filesystem::path current_binary_dir() { function collapse_contiguous_dims (line 20) | std::tuple> collapse_contiguous_dims( function collapse_contiguous_dims (line 83) | std::pair collapse_contiguous_dims( function collapse_contiguous_dims (line 111) | std::pair collapse_contiguous_dims( function Dims (line 117) | Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* =... function Dims (line 148) | Dims get_2d_grid_dims_common(const Shape& shape, const Strides& stride... function Dims (line 173) | Dims get_2d_grid_dims_common( function get_grid_and_block_common (line 221) | std::pair get_grid_and_block_common(int dim0, int dim1, in... FILE: mlx/backend/common/utils.h function namespace (line 11) | namespace mlx::core { function loc (line 189) | int64_t loc{0}; function is_donatable (line 218) | inline bool is_donatable(const array& in, const array& out) { FILE: mlx/backend/cpu/arange.h function namespace (line 8) | namespace mlx::core { FILE: mlx/backend/cpu/arg_reduce.cpp type mlx::core (line 9) | namespace mlx::core { function arg_reduce (line 14) | void arg_reduce(const array& in, array& out, const OpT& op, int axis) { function arg_reduce_dispatch (line 35) | void arg_reduce_dispatch( FILE: mlx/backend/cpu/binary.cpp type mlx::core (line 15) | namespace mlx::core { FILE: mlx/backend/cpu/binary.h function namespace (line 13) | namespace mlx::core { function binary_op_dispatch_dims (line 109) | void binary_op_dispatch_dims( FILE: mlx/backend/cpu/binary_two.h function namespace (line 8) | namespace mlx::core { FILE: mlx/backend/cpu/cholesky.cpp type mlx::core (line 10) | namespace mlx::core { function cholesky_impl (line 13) | void cholesky_impl(const array& a, array& factor, bool upper, Stream s... FILE: mlx/backend/cpu/compiled.cpp type mlx::core (line 20) | namespace mlx::core { type CompilerCache (line 22) | struct CompilerCache { type DLib (line 23) | struct DLib { method DLib (line 24) | DLib(const std::string& libname) { function CompilerCache (line 44) | static CompilerCache& cache() { type DLib (line 23) | struct DLib { method DLib (line 24) | DLib(const std::string& libname) { type detail (line 51) | namespace detail { function compile_available_for_device (line 52) | bool compile_available_for_device(const Device& device) { function build_kernel (line 150) | inline void build_kernel( FILE: mlx/backend/cpu/conv.cpp type mlx::core (line 12) | namespace mlx::core { function slow_conv_1D (line 21) | void slow_conv_1D( function slow_conv_2D (line 110) | void slow_conv_2D( function slow_conv_3D (line 358) | void slow_conv_3D( function dispatch_slow_conv_1D (line 673) | void dispatch_slow_conv_1D( function dispatch_slow_conv_2D (line 726) | void dispatch_slow_conv_2D( function dispatch_slow_conv_3D (line 779) | void dispatch_slow_conv_3D( function flip_spatial_dims_inplace (line 837) | void flip_spatial_dims_inplace( function explicit_gemm_conv_1D_cpu (line 856) | void explicit_gemm_conv_1D_cpu( function explicit_gemm_conv_ND_cpu (line 999) | void explicit_gemm_conv_ND_cpu( function conv_1D_cpu (line 1171) | void conv_1D_cpu( function conv_2D_cpu (line 1213) | void conv_2D_cpu( function conv_3D_cpu (line 1251) | void conv_3D_cpu( FILE: mlx/backend/cpu/copy.cpp type mlx::core (line 11) | namespace mlx::core { function copy_single (line 16) | void copy_single(const array& src, array& dst) { function copy_vector (line 25) | void copy_vector(const array& src, array& dst) { function copy_dims (line 33) | inline void copy_dims( function copy_general_general (line 57) | void copy_general_general( function copy_general_general (line 128) | inline void copy_general_general(const array& src, array& dst) { function copy_general (line 142) | void copy_general( function copy_general (line 165) | inline void copy_general(const array& src, array& dst) { function copy (line 179) | void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { function copy (line 197) | void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { function copy_inplace_dispatch (line 245) | inline void copy_inplace_dispatch( function copy_cpu_inplace (line 298) | void copy_cpu_inplace( function copy_cpu (line 312) | void copy_cpu(const array& src, array& dst, CopyType ctype, Stream str... function copy_cpu_inplace (line 325) | void copy_cpu_inplace( function array (line 380) | array contiguous_copy_cpu(const array& arr, Stream stream) { FILE: mlx/backend/cpu/copy.h function namespace (line 11) | namespace mlx::core { FILE: mlx/backend/cpu/device_info.cpp type mlx::core::cpu (line 15) | namespace mlx::core::cpu { function get_cpu_architecture (line 20) | std::string get_cpu_architecture() { function get_cpu_name (line 51) | std::string get_cpu_name() { function is_available (line 96) | bool is_available() { function device_count (line 100) | int device_count() { FILE: mlx/backend/cpu/device_info.h function namespace (line 9) | namespace mlx::core::cpu { FILE: mlx/backend/cpu/distributed.cpp type mlx::core::distributed (line 10) | namespace mlx::core::distributed { function ensure_row_contiguous (line 12) | std::pair ensure_row_contiguous(const array& arr, Stream ... FILE: mlx/backend/cpu/eig.cpp type mlx::core (line 11) | namespace mlx::core { function complex64_t (line 16) | complex64_t to_complex(T r, T i) { type EigWork (line 21) | struct EigWork {} type EigWork< T, typename std::enable_if::value>::type> (line 24) | struct EigWork< method EigWork (line 36) | EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors) method run (line 65) | void run(T* a, O* values, O* vectors) { type EigWork> (line 116) | struct EigWork> { method EigWork (line 129) | EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors) method run (line 155) | void run(T* a, T* values, T* vectors) { function eig_impl (line 177) | void eig_impl( FILE: mlx/backend/cpu/eigh.cpp type mlx::core (line 11) | namespace mlx::core { type EighWork (line 16) | struct EighWork {} type EighWork< T, typename std::enable_if::value>::type> (line 19) | struct EighWork< method EighWork (line 32) | EighWork(char jobz_, char uplo_, int N_) method run (line 54) | void run(T* vectors, T* values) { type EighWork> (line 71) | struct EighWork> { method EighWork (line 84) | EighWork(char jobz_, char uplo_, int N_) method run (line 111) | void run(T* vectors, R* values) { function eigh_impl (line 143) | void eigh_impl( FILE: mlx/backend/cpu/encoder.cpp type mlx::core::cpu (line 5) | namespace mlx::core::cpu { function CommandEncoder (line 7) | CommandEncoder& get_command_encoder(Stream stream) { FILE: mlx/backend/cpu/encoder.h function namespace (line 10) | namespace mlx::core::cpu { function num_ops_ (line 62) | int num_ops_{0}; FILE: mlx/backend/cpu/eval.cpp type mlx::core::cpu (line 8) | namespace mlx::core::cpu { function eval (line 10) | void eval(array& arr) { FILE: mlx/backend/cpu/eval.h function namespace (line 8) | namespace mlx::core::cpu { FILE: mlx/backend/cpu/fft.cpp type mlx::core (line 10) | namespace mlx::core { FILE: mlx/backend/cpu/gemm.h function namespace (line 6) | namespace mlx::core { FILE: mlx/backend/cpu/gemms/bnns.cpp type mlx::core (line 9) | namespace mlx::core { function BNNSDataType (line 15) | constexpr BNNSDataType to_bnns_dtype() { function BNNSDataType (line 19) | constexpr BNNSDataType to_bnns_dtype() { function BNNSDataType (line 24) | constexpr BNNSDataType to_bnns_dtype() { function matmul_bnns (line 29) | void matmul_bnns( FILE: mlx/backend/cpu/gemms/cblas.cpp type mlx::core (line 7) | namespace mlx::core { FILE: mlx/backend/cpu/gemms/simd_bf16.cpp type mlx::core (line 7) | namespace mlx::core { FILE: mlx/backend/cpu/gemms/simd_fp16.cpp type mlx::core (line 7) | namespace mlx::core { FILE: mlx/backend/cpu/gemms/simd_gemm.h function namespace (line 6) | namespace mlx::core { FILE: mlx/backend/cpu/hadamard.cpp type mlx::core (line 10) | namespace mlx::core { function hadamard_n (line 14) | void hadamard_n(T* out, int n, int m, float scale, size_t size) { function hadamard_m (line 40) | void hadamard_m(T* out, int n, int m, float scale, size_t size) { function hadamard (line 78) | void hadamard(array& out, int n, int m, float scale, Stream stream) { FILE: mlx/backend/cpu/indexing.cpp type mlx::core (line 16) | namespace mlx::core { function offset_neg_idx (line 19) | inline size_t offset_neg_idx(IdxT idx, size_t size) { function offset_neg_idx (line 24) | inline size_t offset_neg_idx(uint32_t idx, size_t) { type None (line 28) | struct None { type Sum (line 34) | struct Sum { type Prod (line 41) | struct Prod { type Max (line 48) | struct Max { type Min (line 55) | struct Min { function gather (line 63) | void gather( function dispatch_gather (line 150) | void dispatch_gather( function gather_axis (line 258) | void gather_axis( function dispatch_gather_axis (line 304) | void dispatch_gather_axis( function scatter (line 402) | void scatter( function dispatch_scatter_inds (line 446) | void dispatch_scatter_inds( function dispatch_scatter (line 472) | void dispatch_scatter( function scatter_axis (line 586) | void scatter_axis(array& out, const array idx, const array& upd, int a... function dispatch_scatter_axis_op (line 628) | void dispatch_scatter_axis_op( function dispatch_scatter_axis (line 645) | void dispatch_scatter_axis( function masked_scatter_impl (line 754) | void masked_scatter_impl(const array& mask, const array& src, array& o... function slice_update_impl (line 858) | void slice_update_impl( FILE: mlx/backend/cpu/inverse.cpp type mlx::core (line 9) | namespace mlx::core { function general_inv (line 12) | void general_inv(T* inv, int N) { function tri_inv (line 72) | void tri_inv(T* inv, int N, bool upper) { function inverse_impl (line 106) | void inverse_impl( FILE: mlx/backend/cpu/jit_compiler.cpp type mlx::core (line 11) | namespace mlx::core { function str_split (line 18) | std::vector str_split(const std::string& str, char delimi... type VisualStudioInfo (line 29) | struct VisualStudioInfo { method VisualStudioInfo (line 30) | VisualStudioInfo() { function VisualStudioInfo (line 80) | const VisualStudioInfo& GetVisualStudioInfo() { method VisualStudioInfo (line 30) | VisualStudioInfo() { FILE: mlx/backend/cpu/jit_compiler.h function namespace (line 6) | namespace mlx::core { FILE: mlx/backend/cpu/logsumexp.cpp type mlx::core (line 12) | namespace mlx::core { function logsumexp (line 19) | void logsumexp(const array& in, array& out, Stream stream) { FILE: mlx/backend/cpu/luf.cpp type mlx::core (line 11) | namespace mlx::core { function luf_impl (line 14) | void luf_impl( FILE: mlx/backend/cpu/masked_mm.cpp type mlx::core (line 13) | namespace mlx::core { function mask_matrix (line 18) | inline void mask_matrix( function segmented_mm (line 57) | inline void segmented_mm( FILE: mlx/backend/cpu/matmul.cpp type mlx::core (line 12) | namespace mlx::core { function matmul_dispatch (line 15) | void matmul_dispatch( function matmul_general (line 69) | void matmul_general( FILE: mlx/backend/cpu/primitives.cpp type mlx::core (line 19) | namespace mlx::core { function reshape (line 21) | void reshape(const array& in, array& out) { function compute_dynamic_offset (line 31) | static std::pair compute_dynamic_offset( FILE: mlx/backend/cpu/qrf.cpp type mlx::core (line 9) | namespace mlx::core { function qrf_impl (line 12) | void qrf_impl(const array& a, array& q, array& r, Stream stream) { FILE: mlx/backend/cpu/quantized.cpp type mlx::core (line 14) | namespace mlx::core { function array (line 18) | array ensure_row_contiguous( function T (line 50) | static inline T dequantize_scale(uint8_t s) { function extract_bits (line 65) | void extract_bits(const uint8_t* w_in, T* w_out) { function _qmm (line 97) | void _qmm( function _qmm_t (line 155) | void _qmm_t( function extract_bits_simd (line 215) | simd::Simd extract_bits_simd(const uint32_t* w) { function _qmm_t_simd (line 240) | void _qmm_t_simd( function _qmm_dispatch_transpose (line 287) | void _qmm_dispatch_transpose( function _qmm_dispatch_group (line 310) | void _qmm_dispatch_group( function _qmm_dispatch_typed (line 341) | void _qmm_dispatch_typed( function _qmm_dispatch_typed (line 384) | void _qmm_dispatch_typed( function _qmm_dispatch (line 421) | void _qmm_dispatch( function fp_qmm (line 450) | void fp_qmm( function fp_qmm_t (line 492) | void fp_qmm_t( function fp_extract_bits_simd (line 536) | simd::Simd fp_extract_bits_simd(const uint32_t* w) { function fp_qmm_t_simd (line 558) | void fp_qmm_t_simd( function fp_qmm_dispatch_transpose (line 603) | void fp_qmm_dispatch_transpose( function fp_qmm_dispatch_mode (line 625) | void fp_qmm_dispatch_mode( function fp_qmm_dispatch_typed (line 656) | void fp_qmm_dispatch_typed( function fp_qmm_dispatch (line 673) | void fp_qmm_dispatch( function _bs_qmm_dispatch_typed (line 701) | void _bs_qmm_dispatch_typed( function _bs_qmm_dispatch (line 749) | void _bs_qmm_dispatch( function fp_bs_qmm_dispatch_mode (line 806) | void fp_bs_qmm_dispatch_mode( function fp_bs_qmm_dispatch_typed (line 847) | void fp_bs_qmm_dispatch_typed( function fp_bs_qmm_dispatch (line 869) | void fp_bs_qmm_dispatch( function to_fp8_e8m0 (line 1049) | uint8_t to_fp8_e8m0(float x) { function to_fp4_e2m1 (line 1064) | uint8_t to_fp4_e2m1(float x) { function fp_quantize_dequantize (line 1094) | void fp_quantize_dequantize( function dispatch_quantize_dequantize (line 1131) | void dispatch_quantize_dequantize( function quantize (line 1149) | void quantize( function dispatch_quantize (line 1214) | void dispatch_quantize( FILE: mlx/backend/cpu/reduce.cpp type mlx::core (line 12) | namespace mlx::core { type Limits (line 15) | struct Limits { type Limits (line 50) | struct Limits { function strided_reduce (line 72) | void strided_reduce( function contiguous_reduce (line 99) | void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U ... function nd_loop (line 115) | void nd_loop( function reduction_op (line 139) | void reduction_op( type AndReduce (line 264) | struct AndReduce { type OrReduce (line 290) | struct OrReduce { type MaxReduce (line 316) | struct MaxReduce { method T (line 318) | T operator()(T y, T x) { type MinReduce (line 341) | struct MinReduce { method T (line 343) | T operator()(T y, T x) { type SumReduce (line 366) | struct SumReduce { method U (line 368) | U operator()(U y, T x) { method T (line 378) | T operator()(simd::Simd x) { type ProdReduce (line 383) | struct ProdReduce { method U (line 385) | U operator()(U y, T x) { method T (line 395) | T operator()(simd::Simd x) { function reduce_dispatch_and_or (line 401) | void reduce_dispatch_and_or( function reduce_dispatch_sum_prod (line 414) | void reduce_dispatch_sum_prod( function reduce_dispatch_min_max (line 435) | void reduce_dispatch_min_max( FILE: mlx/backend/cpu/scan.cpp type mlx::core (line 12) | namespace mlx::core { function contiguous_scan (line 17) | void contiguous_scan( function strided_scan (line 82) | void strided_scan( function scan_op (line 157) | void scan_op( function scan_dispatch (line 194) | void scan_dispatch( FILE: mlx/backend/cpu/select.cpp type mlx::core (line 9) | namespace mlx::core { function select_op (line 14) | void select_op( FILE: mlx/backend/cpu/simd/accelerate_fp16_simd.h function namespace (line 9) | namespace mlx::core::simd { FILE: mlx/backend/cpu/simd/accelerate_simd.h function value (line 66) | value(v){} function T (line 73) | T operator[](int idx) const { FILE: mlx/backend/cpu/simd/base_simd.h function namespace (line 16) | namespace mlx::core::simd { function DEFAULT_UNARY (line 91) | DEFAULT_UNARY(operator!, std::logical_not{}) FILE: mlx/backend/cpu/simd/math.h function namespace (line 7) | namespace mlx::core::simd { function lhs (line 155) | auto lhs = [](auto t) { function rhs (line 167) | auto rhs = [](auto t) { FILE: mlx/backend/cpu/simd/neon_fp16_simd.h function namespace (line 7) | namespace mlx::core::simd { function Simd (line 160) | inline Simd isnan(Simd v) { function float16_t (line 182) | inline float16_t max(Simd x) { function float16_t (line 189) | inline float16_t min(Simd x) { function float16_t (line 196) | inline float16_t sum(Simd x) { function float16_t (line 203) | inline float16_t prod(Simd x) { FILE: mlx/backend/cpu/slicing.h function namespace (line 7) | namespace mlx::core { FILE: mlx/backend/cpu/softmax.cpp type mlx::core (line 12) | namespace mlx::core { function softmax (line 19) | void softmax(const array& in, array& out, Stream stream) { FILE: mlx/backend/cpu/sort.cpp type mlx::core (line 14) | namespace mlx::core { function nan_aware_less (line 24) | bool nan_aware_less(T a, T b) { type StridedIterator (line 35) | struct StridedIterator { method StridedIterator (line 43) | StridedIterator() = default; method StridedIterator (line 45) | explicit StridedIterator(T* ptr, int64_t stride, difference_type off... method StridedIterator (line 48) | explicit StridedIterator(array& arr, int axis, difference_type offse... method reference (line 52) | reference operator*() const { method reference (line 56) | reference operator[](difference_type idx) const { method difference_type (line 85) | difference_type operator-(const StridedIterator& other) const { method StridedIterator (line 90) | StridedIterator& operator++() { method StridedIterator (line 95) | StridedIterator& operator--() { method StridedIterator (line 100) | StridedIterator& operator+=(difference_type diff) { method StridedIterator (line 105) | StridedIterator& operator-=(difference_type diff) { method StridedIterator (line 110) | StridedIterator operator+(difference_type diff) { method StridedIterator (line 114) | StridedIterator operator-(difference_type diff) { function sort (line 124) | void sort(array& out, int axis) { function argsort (line 155) | void argsort(const array& in, array& out, int axis) { function partition (line 218) | void partition(array& out, int axis, int kth) { function argpartition (line 252) | void argpartition(const array& in, array& out, int axis, int kth) { FILE: mlx/backend/cpu/svd.cpp type mlx::core (line 9) | namespace mlx::core { type SVDWork (line 12) | struct SVDWork {} type SVDWork< T, typename std::enable_if::value>::type> (line 15) | struct SVDWork< method SVDWork (line 30) | SVDWork(int N, int M, int K, char jobz) method run (line 69) | void run(T* a, R* s, T* u, T* vt) { type SVDWork> (line 99) | struct SVDWork> { method SVDWork (line 113) | SVDWork(int N, int M, int K, char jobz) method run (line 158) | void run(T* a, R* s, T* u, T* vt) { function svd_impl (line 189) | void svd_impl( FILE: mlx/backend/cpu/ternary.h function namespace (line 9) | namespace mlx::core { function ContiguousIterator (line 100) | ContiguousIterator a_it(shape, a_strides, ndim - 2); function else (line 138) | else if (topt == TernaryOpType::VectorVectorVector) { FILE: mlx/backend/cpu/threefry.cpp type mlx::core::random (line 5) | namespace mlx::core::random { function threefry2x32_hash (line 7) | std::pair threefry2x32_hash( FILE: mlx/backend/cpu/threefry.h function namespace (line 8) | namespace mlx::core::random { FILE: mlx/backend/cpu/unary.cpp type mlx::core (line 12) | namespace mlx::core { FILE: mlx/backend/cpu/unary.h function namespace (line 10) | namespace mlx::core { FILE: mlx/backend/cpu/unary_ops.h function namespace (line 11) | namespace mlx::core::detail { type Square (line 103) | struct Square { type ToFP8 (line 120) | struct ToFP8 { function else (line 154) | struct FromFP8 { FILE: mlx/backend/cuda/allocator.cpp type mlx::core (line 18) | namespace mlx::core { type cu (line 20) | namespace cu { function is_windows (line 32) | bool is_windows() { function supports_managed_memory (line 54) | bool supports_managed_memory() { function unified_free (line 83) | inline void unified_free(void* data) { function cudaMemLocation (line 92) | inline cudaMemLocation cuda_mem_loc(int i) { function cuda_mem_loc (line 99) | inline int cuda_mem_loc(int i) { function CudaBuffer (line 134) | CudaBuffer* SmallSizePool::malloc() { function Buffer (line 183) | Buffer function Buffer (line 270) | Buffer CudaAllocator::malloc(size_t size) { function CudaAllocator (line 385) | CudaAllocator& allocator() { function Buffer (line 397) | Buffer malloc_async(size_t size, CommandEncoder& encoder) { type allocator (line 404) | namespace allocator { function Allocator (line 406) | Allocator& allocator() { function get_active_memory (line 421) | size_t get_active_memory() { function get_peak_memory (line 424) | size_t get_peak_memory() { function reset_peak_memory (line 427) | void reset_peak_memory() { function set_memory_limit (line 430) | size_t set_memory_limit(size_t limit) { function get_memory_limit (line 433) | size_t get_memory_limit() { function get_cache_memory (line 436) | size_t get_cache_memory() { function set_cache_limit (line 439) | size_t set_cache_limit(size_t limit) { function clear_cache (line 442) | void clear_cache() { function set_wired_limit (line 447) | size_t set_wired_limit(size_t) { FILE: mlx/backend/cuda/allocator.h type CudaBuffer (line 21) | struct CudaBuffer { function Block (line 36) | Block* next_free_{nullptr}; function class (line 50) | class CudaAllocator : public allocator::Allocator { FILE: mlx/backend/cuda/compiled.cpp type mlx::core (line 13) | namespace mlx::core { type cu (line 15) | namespace cu { type FusedKernelBuilder (line 17) | struct FusedKernelBuilder { method build (line 25) | void build(const char* name, bool contiguous) { FILE: mlx/backend/cuda/conv.cpp type mlx::core (line 14) | namespace mlx::core { type ConvBackendType (line 18) | enum ConvBackendType { type ConvCacheKey (line 25) | struct ConvCacheKey { function get_conv_settings (line 49) | auto get_conv_settings( function build_conv_graph (line 93) | std::optional build_conv_graph( function array (line 147) | array group_transpose( function prepare_args (line 185) | std::tuple prepare_args( function register_args (line 229) | void register_args( FILE: mlx/backend/cuda/conv/conv.h function namespace (line 8) | namespace mlx::core { FILE: mlx/backend/cuda/cublas_utils.cpp type mlx::core (line 7) | namespace mlx::core { type cublas_utils (line 8) | namespace cublas_utils { type CublasPreference (line 12) | struct CublasPreference { method CublasPreference (line 13) | CublasPreference(cu::Device& device) { function cublasLtMatmulPreference_t (line 38) | cublasLtMatmulPreference_t get_preference(cu::Device& device) { function cublasLtMatrixLayout_t (line 43) | cublasLtMatrixLayout_t create_matrix_layout( FILE: mlx/backend/cuda/cublas_utils.h function namespace (line 10) | namespace cublas_utils { function cublasLtMatrixLayout_t (line 63) | cublasLtMatrixLayout_t c_desc_{nullptr}; FILE: mlx/backend/cuda/cuda.h function namespace (line 11) | namespace mlx::core::cu { FILE: mlx/backend/cuda/cuda_utils.h function namespace (line 10) | namespace mlx::core { function namespace (line 67) | namespace cu { FILE: mlx/backend/cuda/cudnn_utils.cpp type mlx::core (line 6) | namespace mlx::core { function normalized_strides (line 20) | std::vector normalized_strides(const array& x) { function nhwc_to_nchw (line 39) | inline auto nhwc_to_nchw(const array& x) { FILE: mlx/backend/cuda/cudnn_utils.h function namespace (line 16) | namespace mlx::core { function cached_is_updatable_ (line 191) | bool cached_is_updatable_{true}; FILE: mlx/backend/cuda/custom_kernel.cpp type mlx::core::fast (line 15) | namespace mlx::core::fast { function template_arguments_hash (line 28) | std::string template_arguments_hash( function build_kernel (line 51) | std::string build_kernel( function CustomKernelFunction (line 144) | CustomKernelFunction cuda_kernel( function precompiled_cuda_kernel (line 244) | std::vector precompiled_cuda_kernel( FILE: mlx/backend/cuda/delayload.cpp type mlx::core (line 10) | namespace mlx::core { function relative_to_current_binary (line 14) | inline fs::path relative_to_current_binary(const char* relative) { function cublas_bin_dir (line 18) | inline fs::path cublas_bin_dir() { function load_nvrtc (line 26) | fs::path load_nvrtc() { function load_cudnn (line 38) | fs::path load_cudnn() { function FARPROC (line 61) | FARPROC WINAPI delayload_helper(unsigned dliNotify, PDelayLoadInfo pdl... FILE: mlx/backend/cuda/device.cpp type mlx::core::cu (line 14) | namespace mlx::core::cu { function use_cuda_graphs (line 18) | bool use_cuda_graphs() { function is_empty_dim (line 34) | inline bool is_empty_dim(dim3 dim) { function CommandEncoder (line 79) | CommandEncoder& Device::get_command_encoder(Stream s) { function cublasLtHandle_t (line 87) | cublasLtHandle_t Device::get_cublaslt_handle() { function cudnnHandle_t (line 95) | cudnnHandle_t Device::get_cudnn_handle() { function get_graph_limits (line 213) | std::pair get_graph_limits(Device& d) { function cudaGraphNode_t (line 371) | cudaGraphNode_t CommandEncoder::add_kernel_node_raw( function CUgraphNode (line 379) | CUgraphNode CommandEncoder::add_kernel_node_raw( function subgraph_to_key (line 387) | std::pair subgraph_to_key(cudaGraph_t graph) { function Device (line 573) | Device& device(int cuda_device) { function Device (line 588) | Device& device(mlx::core::Device d) { function CommandEncoder (line 592) | CommandEncoder& get_command_encoder(Stream s) { FILE: mlx/backend/cuda/device.h function class (line 22) | class CommandEncoder { function add_temporary (line 103) | void add_temporary(const array& arr) { type GraphNode (line 126) | struct GraphNode { function node_count_ (line 143) | int node_count_{0} function bytes_in_graph_ (line 155) | size_t bytes_in_graph_{0} function is_graph_updatable_ (line 156) | bool is_graph_updatable_{true}; function cublasLtHandle_t (line 208) | cublasLtHandle_t cublaslt_handle_{nullptr}; FILE: mlx/backend/cuda/device_info.cpp type mlx::core (line 14) | namespace mlx::core { type nvmlDevice_st (line 22) | struct nvmlDevice_st type nvmlMemory_t (line 23) | struct nvmlMemory_t { type NVMLState (line 29) | struct NVMLState { function nvml_init (line 38) | bool nvml_init(NVMLState& nvml) { function nvml_get_memory (line 66) | bool nvml_get_memory( function format_uuid (line 84) | std::string format_uuid(const cudaUUID_t& uuid) { type DeviceInfo (line 118) | struct DeviceInfo { type gpu (line 205) | namespace gpu { function is_available (line 207) | bool is_available() { function device_count (line 211) | int device_count() { type cu (line 224) | namespace cu { function is_available (line 226) | bool is_available() { FILE: mlx/backend/cuda/eval.cpp type mlx::core::gpu (line 11) | namespace mlx::core::gpu { function new_stream (line 13) | void new_stream(Stream s) { function eval (line 22) | void eval(array& arr) { function finalize (line 60) | void finalize(Stream s) { function synchronize (line 65) | void synchronize(Stream s) { FILE: mlx/backend/cuda/event.h function namespace (line 14) | namespace mlx::core::cu { function class (line 27) | class CudaEvent { FILE: mlx/backend/cuda/fence.cpp type mlx::core (line 8) | namespace mlx::core { type FenceImpl (line 10) | struct FenceImpl { FILE: mlx/backend/cuda/gemms/cublas_gemm.cpp type mlx::core (line 11) | namespace mlx::core { function cublasComputeType_t (line 15) | cublasComputeType_t dtype_to_compute_type(Dtype dtype) { FILE: mlx/backend/cuda/gemms/cublas_gemm.h function namespace (line 10) | namespace mlx::core { FILE: mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp type mlx::core (line 7) | namespace mlx::core { FILE: mlx/backend/cuda/gemms/gemv.h function namespace (line 7) | namespace mlx::core::cu { FILE: mlx/backend/cuda/gemms/grouped_gemm.h function namespace (line 5) | namespace mlx::core { FILE: mlx/backend/cuda/indexing.cpp type mlx::core (line 24) | namespace mlx::core { function append_indices_arg (line 32) | void append_indices_arg( FILE: mlx/backend/cuda/jit_module.cpp type mlx::core::cu (line 16) | namespace mlx::core::cu { function check_nvrtc_error (line 22) | void check_nvrtc_error(const char* name, nvrtcResult err) { function get_ptx_path (line 123) | std::filesystem::path get_ptx_path( function read_cached_ptx (line 143) | bool read_cached_ptx( function write_cached_ptx (line 177) | void write_cached_ptx( function version_lower_equal (line 211) | inline bool version_lower_equal(Device& device, int major, int minor) { function compiler_supports_device_sass (line 222) | bool compiler_supports_device_sass(Device& device) { function compile (line 274) | void compile( function load_module (line 345) | void load_module( function CUfunction (line 436) | CUfunction JitModule::get_kernel( function JitModule (line 447) | JitModule& get_jit_module( FILE: mlx/backend/cuda/jit_module.h function namespace (line 18) | namespace mlx::core::cu { function append_ptr (line 65) | void append_ptr(const void* v) { function class (line 88) | class JitModule { FILE: mlx/backend/cuda/load.cpp function swap_endianness (line 13) | void swap_endianness(uint8_t* data_bytes, size_t N) { type mlx::core (line 29) | namespace mlx::core { FILE: mlx/backend/cuda/lru_cache.h function namespace (line 14) | namespace mlx::core { FILE: mlx/backend/cuda/matmul.cpp type mlx::core (line 14) | namespace mlx::core { function check_transpose (line 18) | std::tuple function ensure_batch_contiguous (line 33) | std::tuple function array (line 52) | array ensure_row_contiguous( function gemm_and_bias (line 65) | void gemm_and_bias( function gather_mm_rhs (line 139) | void gather_mm_rhs( FILE: mlx/backend/cuda/no_cuda.cpp type mlx::core (line 6) | namespace mlx::core { type cu (line 8) | namespace cu { function is_available (line 10) | bool is_available() { type fast (line 16) | namespace fast { function CustomKernelFunction (line 18) | CustomKernelFunction cuda_kernel( function precompiled_cuda_kernel (line 29) | std::vector precompiled_cuda_kernel( FILE: mlx/backend/cuda/primitives.cpp type mlx::core (line 8) | namespace mlx::core { type distributed (line 37) | namespace distributed { FILE: mlx/backend/cuda/quantized/cublas_qqmm.cpp type mlx::core (line 12) | namespace mlx::core { type QuantModeConfig (line 16) | struct QuantModeConfig { function QuantModeConfig (line 22) | QuantModeConfig get_quant_mode_config(const std::string& mode) { FILE: mlx/backend/cuda/quantized/cublas_qqmm.h function namespace (line 10) | namespace mlx::core { FILE: mlx/backend/cuda/quantized/no_qqmm_impl.cpp type mlx::core (line 5) | namespace mlx::core { function qqmm_impl (line 6) | void qqmm_impl( FILE: mlx/backend/cuda/quantized/qmm/qmm.h function namespace (line 10) | namespace mlx::core { FILE: mlx/backend/cuda/quantized/qqmm.cpp type mlx::core (line 13) | namespace mlx::core { function quantize_input (line 17) | std::tuple quantize_input( function GemmScalars (line 55) | GemmScalars create_nvfp4_scalars( FILE: mlx/backend/cuda/quantized/qqmm_impl.cpp type mlx::core (line 6) | namespace mlx::core { function qqmm_impl (line 8) | void qqmm_impl( FILE: mlx/backend/cuda/quantized/qqmm_impl.h function namespace (line 9) | namespace mlx::core { FILE: mlx/backend/cuda/quantized/qqmm_utils.h function namespace (line 8) | namespace mlx::core { function array (line 30) | inline array pad_and_swizzle_scales( FILE: mlx/backend/cuda/quantized/quantized.cpp type mlx::core (line 13) | namespace mlx::core { FILE: mlx/backend/cuda/quantized/quantized.h function namespace (line 6) | namespace mlx::core { FILE: mlx/backend/cuda/quantized/quantized_utils.h function namespace (line 6) | namespace mlx::core { FILE: mlx/backend/cuda/scaled_dot_product_attention.cpp type mlx::core (line 11) | namespace mlx::core { function array (line 15) | array prepare_sdpa_input(const array& x, Stream s) { function array (line 28) | array prepare_sdpa_sinks(const array& sinks, Stream s) { function malloc_with_same_layout (line 40) | void malloc_with_same_layout( function use_cudnn_for_decoding (line 72) | bool use_cudnn_for_decoding( function array (line 111) | array unslice_kv(const array& kv) { type SDPACacheKey (line 126) | struct SDPACacheKey { function build_sdpa_cache_key (line 142) | inline BytesKey build_sdpa_cache_key( type UIDS (line 190) | enum UIDS { function DnnGraph (line 208) | DnnGraph build_sdpa_graph( function DnnGraph (line 259) | DnnGraph build_sdpa_backward_graph( function supports_sdpa_cudnn (line 310) | bool supports_sdpa_cudnn( function sdpa_cudnn (line 347) | void sdpa_cudnn( function sdpa_backward_cudnn (line 446) | void sdpa_backward_cudnn( type fast (line 545) | namespace fast { FILE: mlx/backend/cuda/slicing.cpp type mlx::core (line 12) | namespace mlx::core { function concatenate_gpu (line 14) | void concatenate_gpu( function array (line 44) | array compute_dynamic_offset( FILE: mlx/backend/cuda/utils.cpp type mlx::core (line 11) | namespace mlx::core { function check_cublas_error (line 13) | void check_cublas_error(const char* name, cublasStatus_t err) { function check_cuda_error (line 21) | void check_cuda_error(const char* name, cudaError_t err) { function check_cuda_error (line 28) | void check_cuda_error(const char* name, CUresult err) { function check_cudnn_error (line 36) | void check_cudnn_error(const char* name, cudnnStatus_t err) { FILE: mlx/backend/cuda/utils.h function namespace (line 11) | namespace mlx::core { type Dtype (line 41) | struct Dtype FILE: mlx/backend/cuda/worker.cpp type mlx::core::cu (line 6) | namespace mlx::core::cu { FILE: mlx/backend/cuda/worker.h function namespace (line 13) | namespace mlx::core::cu { FILE: mlx/backend/gpu/copy.cpp type mlx::core (line 9) | namespace mlx::core { function copy_gpu (line 11) | void copy_gpu(const array& in, array& out, CopyType ctype) { function copy_gpu_inplace (line 15) | void copy_gpu_inplace( function copy_gpu_inplace (line 25) | void copy_gpu_inplace( function array (line 37) | array contiguous_copy_gpu(const array& arr, const Stream& s) { function array (line 43) | array flatten_in_eval(const array& x, int start_axis, int end_axis, St... function array (line 57) | array reshape_in_eval(const array& x, Shape shape, Stream s) { function array (line 63) | array transpose_in_eval(const array& x, const std::vector& axes) { function array (line 84) | array swapaxes_in_eval(const array& x, int axis1, int axis2) { FILE: mlx/backend/gpu/copy.h function namespace (line 11) | namespace mlx::core { FILE: mlx/backend/gpu/device_info.h function namespace (line 11) | namespace mlx::core::gpu { FILE: mlx/backend/gpu/eval.h function namespace (line 11) | namespace mlx::core::gpu { FILE: mlx/backend/gpu/primitives.cpp type mlx::core (line 21) | namespace mlx::core { FILE: mlx/backend/gpu/scan.h function namespace (line 6) | namespace mlx::core { FILE: mlx/backend/gpu/slicing.cpp type mlx::core (line 7) | namespace mlx::core { function slice_gpu (line 9) | void slice_gpu( function pad_gpu (line 18) | void pad_gpu( FILE: mlx/backend/gpu/slicing.h function namespace (line 7) | namespace mlx::core { FILE: mlx/backend/metal/allocator.cpp type mlx::core (line 12) | namespace mlx::core { type allocator (line 17) | namespace allocator { function Allocator (line 19) | Allocator& allocator() { type metal (line 32) | namespace metal { function Buffer (line 104) | Buffer MetalAllocator::malloc(size_t size) { function Buffer (line 208) | Buffer MetalAllocator::make_buffer(void* ptr, size_t size) { function MetalAllocator (line 235) | MetalAllocator& allocator() { function set_cache_limit (line 245) | size_t set_cache_limit(size_t limit) { function set_memory_limit (line 248) | size_t set_memory_limit(size_t limit) { function get_memory_limit (line 251) | size_t get_memory_limit() { function set_wired_limit (line 254) | size_t set_wired_limit(size_t limit) { function get_active_memory (line 263) | size_t get_active_memory() { function get_peak_memory (line 266) | size_t get_peak_memory() { function reset_peak_memory (line 269) | void reset_peak_memory() { function get_cache_memory (line 272) | size_t get_cache_memory() { function clear_cache (line 275) | void clear_cache() { FILE: mlx/backend/metal/allocator.h function namespace (line 14) | namespace mlx::core::metal { FILE: mlx/backend/metal/binary.cpp type mlx::core (line 19) | namespace mlx::core { function get_kernel_name (line 21) | std::string get_kernel_name( function binary_op_gpu_inplace (line 65) | void binary_op_gpu_inplace( function binary_op_gpu (line 165) | void binary_op_gpu( function binary_op_gpu (line 179) | void binary_op_gpu( function binary_op_gpu_inplace (line 187) | void binary_op_gpu_inplace( function binary_op_gpu (line 196) | void binary_op_gpu( function binary_op_gpu (line 209) | void binary_op_gpu( FILE: mlx/backend/metal/binary.h function namespace (line 7) | namespace mlx::core { FILE: mlx/backend/metal/compiled.cpp type mlx::core (line 14) | namespace mlx::core { function build_kernel (line 16) | inline void build_kernel( FILE: mlx/backend/metal/conv.cpp type mlx::core (line 19) | namespace mlx::core { function array (line 23) | inline array function explicit_gemm_conv_ND_gpu (line 34) | void explicit_gemm_conv_ND_gpu( function explicit_gemm_conv_group_ND_gpu (line 105) | void explicit_gemm_conv_group_ND_gpu( function implicit_gemm_conv_2D_gpu (line 191) | void implicit_gemm_conv_2D_gpu( function implicit_gemm_conv_2D_general_gpu (line 324) | void implicit_gemm_conv_2D_general_gpu( function implicit_gemm_conv_3D_gpu (line 503) | void implicit_gemm_conv_3D_gpu( function pad_and_slice_conv_3D_gpu (line 624) | void pad_and_slice_conv_3D_gpu( function dispatch_conv_3D_gpu (line 671) | void dispatch_conv_3D_gpu( function winograd_conv_2D_gpu (line 714) | void winograd_conv_2D_gpu( function depthwise_conv_2D_gpu (line 908) | void depthwise_conv_2D_gpu( function dispatch_conv_2D_gpu (line 970) | void dispatch_conv_2D_gpu( function depthwise_conv_1D_gpu (line 1032) | void depthwise_conv_1D_gpu( function conv_1D_gpu (line 1078) | void conv_1D_gpu( function conv_2D_gpu (line 1165) | void conv_2D_gpu( function conv_3D_gpu (line 1210) | void conv_3D_gpu( FILE: mlx/backend/metal/copy.cpp type mlx::core (line 9) | namespace mlx::core { function copy_gpu (line 13) | void copy_gpu(const array& in, array& out, CopyType ctype, const Strea... function copy_gpu_inplace (line 26) | void copy_gpu_inplace( function fill_gpu (line 182) | void fill_gpu(const array& val, array& out, const Stream& s) { function reshape_gpu (line 216) | void reshape_gpu(const array& in, array& out, Stream s) { FILE: mlx/backend/metal/custom_kernel.cpp type mlx::core::fast (line 14) | namespace mlx::core::fast { type CustomKernelCache (line 16) | struct CustomKernelCache { function CustomKernelCache (line 20) | static CustomKernelCache& cache() { function write_signature (line 25) | std::string write_signature( function write_template (line 153) | std::string write_template( function CustomKernelFunction (line 175) | CustomKernelFunction metal_kernel( FILE: mlx/backend/metal/device.cpp type std (line 16) | namespace std { type hash> (line 20) | struct hash> { type mlx::core::metal (line 28) | namespace mlx::core::metal { function get_metal_version (line 34) | auto get_metal_version() { function load_device (line 48) | auto load_device() { function load_library_from_path (line 57) | std::pair load_library_from_path( function load_colocated_library (line 103) | std::pair load_colocated_library( function load_swiftpm_library (line 114) | std::pair load_swiftpm_library( function CommandEncoder (line 523) | CommandEncoder& Device::get_command_encoder(int index) { function Device (line 832) | Device& device(mlx::core::Device) { function new_scoped_memory_pool (line 840) | std::unique_ptr> new_scoped_memory_po... FILE: mlx/backend/metal/device.h function namespace (line 16) | namespace mlx::core::metal { function set_threadgroup_memory_length (line 80) | void set_threadgroup_memory_length(size_t length, int idx) { function ConcurrentContext (line 84) | ConcurrentContext start_concurrent() { function needs_commit (line 90) | bool needs_commit() const; function buffer_ops_ (line 108) | int buffer_ops_{0} function buffer_sizes_ (line 109) | size_t buffer_sizes_{0} function needs_barrier_ (line 116) | bool needs_barrier_{false}; FILE: mlx/backend/metal/device_info.cpp type mlx::core::gpu (line 9) | namespace mlx::core::gpu { function is_available (line 11) | bool is_available() { function device_count (line 15) | int device_count() { FILE: mlx/backend/metal/distributed.cpp type mlx::core::distributed (line 15) | namespace mlx::core::distributed { FILE: mlx/backend/metal/eval.cpp type mlx::core::gpu (line 10) | namespace mlx::core::gpu { function new_stream (line 12) | void new_stream(Stream stream) { function check_error (line 18) | inline void check_error(MTL::CommandBuffer* cbuf) { function eval (line 27) | void eval(array& arr) { function finalize (line 74) | void finalize(Stream s) { function synchronize (line 83) | void synchronize(Stream s) { FILE: mlx/backend/metal/event.cpp type mlx::core (line 7) | namespace mlx::core { FILE: mlx/backend/metal/fence.cpp type mlx::core (line 7) | namespace mlx::core { type FenceImpl (line 9) | struct FenceImpl { method FenceImpl (line 10) | FenceImpl() { FILE: mlx/backend/metal/fft.cpp type mlx::core (line 18) | namespace mlx::core { function supported_radices (line 30) | inline const std::vector supported_radices() { function prime_factors (line 35) | std::vector prime_factors(int n) { type FourStepParams (line 52) | struct FourStepParams { type FFTPlan (line 70) | struct FFTPlan { function next_fast_n (line 85) | int next_fast_n(int n) { function plan_stockham_fft (line 89) | std::vector plan_stockham_fft(int n) { function FFTPlan (line 113) | FFTPlan plan_fft(int n) { function compute_elems_per_thread (line 174) | int compute_elems_per_thread(FFTPlan plan) { function mod_exp (line 231) | int mod_exp(int x, int y, int n) { function primitive_root (line 243) | int primitive_root(int n) { function compute_raders_constants (line 261) | std::tuple compute_raders_constants( function compute_bluestein_constants (line 303) | std::pair compute_bluestein_constants(int n, int blueste... function multi_upload_bluestein_fft (line 349) | void multi_upload_bluestein_fft( function four_step_fft (line 474) | void four_step_fft( function fft_op (line 508) | void fft_op( function fft_op (line 751) | void fft_op( function nd_fft_op (line 762) | void nd_fft_op( FILE: mlx/backend/metal/hadamard.cpp type mlx::core (line 13) | namespace mlx::core { function gen_hadamard_codelet (line 17) | std::string gen_hadamard_codelet(int m) { function hadamard_mn_contiguous (line 60) | void hadamard_mn_contiguous( FILE: mlx/backend/metal/indexing.cpp type mlx::core (line 20) | namespace mlx::core { function make_index_args (line 24) | std::pair make_index_args( function make_op (line 42) | inline std::string make_op(typename T::ReduceType r, const std::string... FILE: mlx/backend/metal/jit/includes.h function namespace (line 5) | namespace mlx::core::metal { FILE: mlx/backend/metal/jit/indexing.h function std (line 3) | constexpr std::string_view gather_kernels = R"( function std (line 36) | constexpr std::string_view scatter_kernels = R"( FILE: mlx/backend/metal/jit_kernels.cpp type mlx::core (line 9) | namespace mlx::core { function append_binary_kernels (line 54) | void append_binary_kernels( FILE: mlx/backend/metal/kernels.h function namespace (line 10) | namespace mlx::core { FILE: mlx/backend/metal/kernels/arange.h function arange (line 3) | [[kernel]] void arange( FILE: mlx/backend/metal/kernels/atomic.h function uint (line 162) | uint operator()(uint_or_packed init, T update, size_t elem_offset) { function mlx_atomic_update_and_store (line 170) | void mlx_atomic_update_and_store( function condition (line 193) | static bool condition(T a, T b) { function T (line 199) | T operator()(T a, T b) { function condition (line 207) | static bool condition(T a, T b) { function T (line 213) | T operator()(T a, T b) { function condition (line 220) | static bool condition(T a, T b) { function T (line 225) | T operator()(T a, T b) { function condition (line 232) | static bool condition(T a, T b) { function T (line 236) | T operator()(T a, T b) { function condition (line 243) | static bool condition(T a, T b) { function T (line 247) | T operator()(T a, T b) { FILE: mlx/backend/metal/kernels/bf16.h type bfloat (line 9) | typedef bfloat bfloat16_t; function bfloat16_to_uint16 (line 10) | inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { function bfloat16_t (line 14) | inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { FILE: mlx/backend/metal/kernels/bf16_math.h function namespace (line 226) | namespace metal { function namespace (line 370) | namespace metal { FILE: mlx/backend/metal/kernels/binary.h function binary_ss (line 4) | [[kernel]] void binary_ss( function binary_sv (line 13) | void binary_sv( function binary_vs (line 32) | void binary_vs( function binary_vv (line 51) | void binary_vv( function binary_sv2 (line 70) | void binary_sv2( function binary_vs2 (line 90) | void binary_vs2( function binary_vv2 (line 110) | void binary_vv2( function binary_g_nd1 (line 130) | void binary_g_nd1( function binary_g_nd2 (line 143) | void binary_g_nd2( function binary_g_nd3 (line 158) | void binary_g_nd3( FILE: mlx/backend/metal/kernels/binary_ops.h type Add (line 10) | struct Add { type FloorDivide (line 17) | struct FloorDivide { type Divide (line 36) | struct Divide { type Remainder (line 43) | struct Remainder { type Equal (line 72) | struct Equal { type NaNEqual (line 79) | struct NaNEqual { type Greater (line 94) | struct Greater { type GreaterEqual (line 101) | struct GreaterEqual { type Less (line 108) | struct Less { type LessEqual (line 115) | struct LessEqual { type LogAddExp (line 122) | struct LogAddExp { function complex64_t (line 136) | complex64_t operator()(complex64_t x, complex64_t y) { type Maximum (line 155) | struct Maximum { type Minimum (line 178) | struct Minimum { type Multiply (line 201) | struct Multiply { type NotEqual (line 208) | struct NotEqual { type Power (line 219) | struct Power { type Subtract (line 262) | struct Subtract { type LogicalAnd (line 269) | struct LogicalAnd { type LogicalOr (line 276) | struct LogicalOr { type BitwiseAnd (line 283) | struct BitwiseAnd { type BitwiseOr (line 290) | struct BitwiseOr { type BitwiseXor (line 297) | struct BitwiseXor { type LeftShift (line 304) | struct LeftShift { type RightShift (line 311) | struct RightShift { type ArcTan2 (line 318) | struct ArcTan2 { type DivMod (line 325) | struct DivMod { FILE: mlx/backend/metal/kernels/binary_two.h function binary_ss (line 4) | [[kernel]] void binary_ss( function binary_sv (line 16) | void binary_sv( function binary_vs (line 40) | void binary_vs( function binary_vv (line 64) | void binary_vv( function binary_sv2 (line 88) | void binary_sv2( function binary_vs2 (line 113) | void binary_vs2( function binary_vv2 (line 138) | void binary_vv2( function binary_g_nd1 (line 163) | void binary_g_nd1( function binary_g_nd2 (line 179) | void binary_g_nd2( function binary_g_nd3 (line 197) | void binary_g_nd3( FILE: mlx/backend/metal/kernels/cexpf.h function get_float_word (line 32) | inline void get_float_word(thread uint32_t& i, float d) { function get_float_word (line 38) | inline void get_float_word(thread int32_t& i, float d) { function set_float_word (line 44) | inline void set_float_word(thread float& d, uint32_t i) { function frexp_expf (line 50) | inline float frexp_expf(float x, thread int* expt) { FILE: mlx/backend/metal/kernels/complex.h type complex64_t (line 9) | struct complex64_t function complex64_t (line 145) | constexpr complex64_t operator*(complex64_t a, complex64_t b) { FILE: mlx/backend/metal/kernels/copy.h function copy_s (line 4) | void copy_s( function copy_v (line 22) | void copy_v( function copy_s2 (line 40) | void copy_s2( function copy_v2 (line 59) | void copy_v2( function copy_g_nd1 (line 78) | void copy_g_nd1( function copy_g_nd2 (line 88) | void copy_g_nd2( function copy_g_nd3 (line 100) | void copy_g_nd3( function copy_g (line 113) | void copy_g( function copy_gg_nd2 (line 151) | void copy_gg_nd2( function copy_gg_nd3 (line 163) | void copy_gg_nd3( function copy_gg (line 175) | void copy_gg( function copy_gg_dynamic_nd2 (line 218) | void copy_gg_dynamic_nd2( function copy_gg_dynamic_nd3 (line 232) | void copy_gg_dynamic_nd3( FILE: mlx/backend/metal/kernels/erf.h function erf (line 12) | float erf(float a) { function erfinv (line 42) | float erfinv(float a) { FILE: mlx/backend/metal/kernels/expm1f.h function expm1f_scaled_unchecked (line 43) | float expm1f_scaled_unchecked(float a, float b) { function expm1f (line 80) | float expm1f(float a) { FILE: mlx/backend/metal/kernels/fft.h function fft (line 180) | [[kernel]] void fft( function rader_fft (line 219) | [[kernel]] void rader_fft( function bluestein_fft (line 374) | [[kernel]] void bluestein_fft( function four_step_fft (line 443) | void four_step_fft( FILE: mlx/backend/metal/kernels/fft/radix.h function METAL_FUNC (line 19) | METAL_FUNC float2 complex_mul(float2 a, float2 b) { function METAL_FUNC (line 24) | METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) { function METAL_FUNC (line 29) | METAL_FUNC float2 get_twiddle(int k, int p) { function METAL_FUNC (line 36) | METAL_FUNC void radix2(thread float2* x, thread float2* y) { function METAL_FUNC (line 41) | METAL_FUNC void radix3(thread float2* x, thread float2* y) { function METAL_FUNC (line 56) | METAL_FUNC void radix4(thread float2* x, thread float2* y) { function METAL_FUNC (line 69) | METAL_FUNC void radix5(thread float2* x, thread float2* y) { function METAL_FUNC (line 96) | METAL_FUNC void radix6(thread float2* x, thread float2* y) { function METAL_FUNC (line 122) | METAL_FUNC void radix7(thread float2* x, thread float2* y) { function METAL_FUNC (line 151) | METAL_FUNC void radix8(thread float2* x, thread float2* y) { function METAL_FUNC (line 201) | METAL_FUNC void radix11(thread float2* x, thread float2* y) { function METAL_FUNC (line 290) | METAL_FUNC void radix13(thread float2* x, thread float2* y) { FILE: mlx/backend/metal/kernels/fft/readwrite.h function METAL_FUNC (line 77) | METAL_FUNC float2 post_in(float2 elem) const { function METAL_FUNC (line 82) | METAL_FUNC float2 post_in(float elem) const { function METAL_FUNC (line 86) | METAL_FUNC float2 pre_out(float2 elem) const { function METAL_FUNC (line 90) | METAL_FUNC float2 pre_out(float2 elem, int length) const { function METAL_FUNC (line 94) | METAL_FUNC bool out_of_bounds() const { function METAL_FUNC (line 123) | METAL_FUNC void write() const { function METAL_FUNC (line 146) | METAL_FUNC void load_padded(int length, const device float2* w_k) const { function METAL_FUNC (line 163) | METAL_FUNC void write_padded(int length, const device float2* w_k) const { function METAL_FUNC (line 180) | METAL_FUNC void compute_strided_indices(int stride, int overall_n) { function METAL_FUNC (line 202) | METAL_FUNC void load_strided(int stride, int overall_n) { function METAL_FUNC (line 210) | METAL_FUNC void write_strided(int stride, int overall_n) { function write_padded (line 505) | float>::write_padded( FILE: mlx/backend/metal/kernels/fp4.h type fp4_e2m1 (line 3) | struct fp4_e2m1 { function operator (line 33) | operator float16_t() { function operator (line 39) | operator float() { function operator (line 43) | operator bfloat16_t() { FILE: mlx/backend/metal/kernels/fp8.h function else (line 3) | struct fp8_e4m3 { function operator (line 32) | operator float16_t() { function operator (line 40) | operator bfloat16_t() { function operator (line 44) | operator float() { function operator (line 69) | operator bfloat16_t() { function operator (line 74) | operator float() { FILE: mlx/backend/metal/kernels/fp_quantized.h function get_pack_factor (line 21) | short get_pack_factor() { function get_bytes_per_pack (line 26) | short get_bytes_per_pack() { function U (line 53) | U operator()(uint8_t x) { function load_vector (line 63) | inline void load_vector(const device T* x, thread U* x_thread) { function load_vector_safe (line 71) | inline void load_vector_safe(const device T* x, thread U* x_thread, int ... function U (line 82) | inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) { function U (line 103) | inline U function qouter (line 125) | inline void qouter(const thread uint8_t* w, U x, U scale, thread U* resu... function load_safe (line 218) | void load_safe(short2 src_tile_dim) const { function next (line 244) | void next() { type U (line 283) | typedef float U; type U (line 346) | typedef float U; type U (line 408) | typedef float U; type U (line 549) | typedef float U; type vec_w (line 550) | typedef struct { function fp_qmm_t_impl (line 635) | void fp_qmm_t_impl( function fp_qmm_n_impl (line 759) | void fp_qmm_n_impl( function fp_qmv_quad (line 972) | void fp_qmv_quad( function fp_qmv_fast (line 1011) | void fp_qmv_fast( function fp_qmv (line 1050) | void fp_qmv( function fp_qvm (line 1089) | void fp_qvm( function fp_qmm_t (line 1179) | void fp_qmm_t( function fp_qmm_n (line 1233) | void fp_qmm_n( function fp_gather_qmv_fast (line 1282) | void fp_gather_qmv_fast( function fp_gather_qmv (line 1331) | void fp_gather_qmv( function fp_gather_qvm (line 1380) | void fp_gather_qvm( function fp_gather_qmm_t (line 1436) | void fp_gather_qmm_t( function fp_gather_qmm_n (line 1499) | void fp_gather_qmm_n( function fp_gather_qmm_rhs (line 1566) | void fp_gather_qmm_rhs( function fp_quantize (line 1750) | void fp_quantize( function fp_dequantize (line 1792) | void fp_dequantize( function fp_quantize_dequantize (line 1824) | void fp_quantize_dequantize( FILE: mlx/backend/metal/kernels/fp_quantized_nax.h function get_pack_factor (line 21) | short get_pack_factor() { function get_bytes_per_pack (line 26) | short get_bytes_per_pack() { function U (line 53) | U operator()(uint8_t x) { function load_safe (line 146) | void load_safe(short2 src_tile_dim) const { function next (line 176) | void next() { function fp_qmm_t_impl (line 199) | void fp_qmm_t_impl( function fp_qmm_n_impl (line 343) | void fp_qmm_n_impl( function fp_qmm_t_nax (line 550) | void fp_qmm_t_nax( function fp_qmm_n_nax (line 606) | void fp_qmm_n_nax( function fp_gather_qmm_t_nax (line 665) | void fp_gather_qmm_t_nax( function fp_gather_qmm_n_nax (line 730) | void fp_gather_qmm_n_nax( function fp_gather_qmm_rhs_nax (line 798) | void fp_gather_qmm_rhs_nax( FILE: mlx/backend/metal/kernels/gemv_masked.h type _NoMask (line 10) | struct _NoMask { type nomask_t (line 27) | typedef struct _NoMask nomask_t; function METAL_FUNC (line 33) | METAL_FUNC OutT apply(InT x) const { function METAL_FUNC (line 125) | static METAL_FUNC void run( function METAL_FUNC (line 404) | static METAL_FUNC void run( FILE: mlx/backend/metal/kernels/hadamard.h function hadamard_m (line 142) | void hadamard_m( FILE: mlx/backend/metal/kernels/indexing/gather_axis.h function gather_axis (line 6) | void gather_axis( FILE: mlx/backend/metal/kernels/indexing/gather_front.h function gather_front (line 8) | void gather_front( FILE: mlx/backend/metal/kernels/indexing/masked_scatter.h function masked_assign_impl (line 8) | [[kernel]] void masked_assign_impl( FILE: mlx/backend/metal/kernels/indexing/scatter.h function slice_update_op_impl (line 70) | void slice_update_op_impl( FILE: mlx/backend/metal/kernels/indexing/scatter_axis.h function scatter_axis (line 12) | [[kernel]] void scatter_axis( FILE: mlx/backend/metal/kernels/logging.h function namespace (line 8) | namespace mlx { type os_log (line 15) | struct os_log { FILE: mlx/backend/metal/kernels/logsumexp.h function logsumexp (line 4) | void logsumexp( function logsumexp_looped (line 76) | void logsumexp_looped( FILE: mlx/backend/metal/kernels/quantized.h function get_pack_factor (line 18) | short get_pack_factor() { function get_bytes_per_pack (line 23) | short get_bytes_per_pack() { function U (line 29) | inline U load_vector(const device T* x, thread U* x_thread) { function U (line 108) | inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { function U (line 192) | inline U qdot( function U (line 293) | inline U qdot_safe( function qouter (line 395) | inline void function dequantize (line 484) | inline void function load_safe (line 641) | void load_safe(short2 src_tile_dim) const { function next (line 671) | void next() { type U (line 711) | typedef float U; type U (line 772) | typedef float U; type U (line 840) | typedef float U; type U (line 1001) | typedef float U; type vec_w (line 1002) | typedef struct { function qmm_t_impl (line 1094) | void qmm_t_impl( function qmm_n_impl (line 1220) | void qmm_n_impl( function affine_qmv_quad (line 1443) | void affine_qmv_quad( function affine_qmv_fast (line 1495) | void affine_qmv_fast( function affine_qmv (line 1547) | void affine_qmv( function affine_qvm (line 1599) | void affine_qvm( function affine_qmm_t (line 1715) | void affine_qmm_t( function affine_qmm_n (line 1773) | void affine_qmm_n( function affine_gather_qmv_fast (line 1826) | void affine_gather_qmv_fast( function affine_gather_qmv (line 1888) | void affine_gather_qmv( function affine_gather_qvm (line 1950) | void affine_gather_qvm( function affine_gather_qmm_t (line 2019) | void affine_gather_qmm_t( function affine_gather_qmm_n (line 2086) | void affine_gather_qmm_n( function affine_gather_qmm_rhs (line 2157) | void affine_gather_qmm_rhs( function affine_quantize (line 2344) | void affine_quantize( function affine_dequantize (line 2449) | void affine_dequantize( FILE: mlx/backend/metal/kernels/quantized_nax.h function get_pack_factor (line 21) | short get_pack_factor() { function get_bytes_per_pack (line 26) | short get_bytes_per_pack() { function U (line 32) | inline U load_vector(const device T* x, thread U* x_thread) { function U (line 111) | inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { function U (line 195) | inline U qdot( function U (line 296) | inline U qdot_safe( function qouter (line 398) | inline void function dequantize (line 487) | inline void function load_safe (line 644) | void load_safe(short2 src_tile_dim) const { function next (line 674) | void next() { function load_safe (line 784) | void load_safe(short2 src_tile_dim) const { function next (line 814) | void next() { function qmm_t_nax_tgp_impl (line 938) | void qmm_t_nax_tgp_impl( function qmm_n_nax_tgp_impl (line 1082) | void qmm_n_nax_tgp_impl( function affine_qmm_t_nax (line 1205) | void affine_qmm_t_nax( function affine_qmm_n_nax (line 1264) | void affine_qmm_n_nax( function affine_gather_qmm_t_nax (line 1324) | void affine_gather_qmm_t_nax( function affine_gather_qmm_n_nax (line 1392) | void affine_gather_qmm_n_nax( function affine_gather_qmm_rhs_nax (line 1461) | void affine_gather_qmm_rhs_nax( FILE: mlx/backend/metal/kernels/quantized_utils.h function typename (line 77) | typename loader_b_t> FILE: mlx/backend/metal/kernels/reduction/ops.h type None (line 29) | struct None { function simd_reduce_impl (line 40) | bool simd_reduce_impl(bool val) { function update (line 67) | void update(device bool* out, bool val) { function simd_reduce_impl (line 81) | bool simd_reduce_impl(bool val) { function update (line 108) | void update(device bool* out, bool val) { function U (line 135) | U operator()(U a, U b) { function U (line 157) | U operator()(U a, U b) { FILE: mlx/backend/metal/kernels/reduction/reduce_all.h function all_reduce (line 9) | void all_reduce( FILE: mlx/backend/metal/kernels/reduction/reduce_col.h function col_reduce_small (line 4) | void col_reduce_small( function col_reduce_longcolumn (line 97) | void col_reduce_longcolumn( function col_reduce_looped (line 163) | void col_reduce_looped( function col_reduce_2pass (line 302) | void col_reduce_2pass( FILE: mlx/backend/metal/kernels/reduction/reduce_init.h function init_reduce (line 4) | [[kernel]] void init_reduce( FILE: mlx/backend/metal/kernels/reduction/reduce_row.h function per_thread_row_reduce (line 19) | void per_thread_row_reduce( function per_thread_row_reduce (line 70) | void per_thread_row_reduce( function per_thread_row_reduce (line 98) | void per_thread_row_reduce( function threadgroup_reduce (line 129) | void threadgroup_reduce( function thread_reduce (line 165) | void function row_reduce_small (line 199) | void row_reduce_small( FILE: mlx/backend/metal/kernels/scan.h function U (line 45) | U simd_scan_impl(U x) { function U (line 49) | U simd_exclusive_scan_impl(U x) { function U (line 66) | U simd_scan_impl(U x) { function U (line 70) | U simd_exclusive_scan_impl(U x) { function bool (line 76) | struct CumProd { function U (line 107) | U simd_scan(U x) { function U (line 115) | U simd_exclusive_scan(U x) { function U (line 130) | U simd_scan(U x) { function U (line 138) | U simd_exclusive_scan(U x) { function U (line 153) | U simd_scan(U x) { function U (line 161) | U simd_exclusive_scan(U x) { function load_unsafe (line 168) | inline void load_unsafe(U values[N_READS], const device T* input) { function load_safe (line 181) | inline void load_safe( function write_unsafe (line 200) | inline void write_unsafe(U values[N_READS], device U* out) { function write_safe (line 213) | inline void write_safe(U values[N_READS], device U* out, int start, int ... function contiguous_scan (line 236) | void contiguous_scan( function strided_scan (line 392) | void strided_scan( FILE: mlx/backend/metal/kernels/sdpa_vector.h function sdpa_vector (line 16) | void sdpa_vector( function sdpa_vector_2pass_1 (line 180) | void sdpa_vector_2pass_1( function sdpa_vector_2pass_2 (line 321) | [[kernel]] void sdpa_vector_2pass_2( FILE: mlx/backend/metal/kernels/softmax.h function softmax_single_row (line 11) | void softmax_single_row( function softmax_looped (line 101) | void softmax_looped( FILE: mlx/backend/metal/kernels/sort.h function METAL_FUNC (line 35) | METAL_FUNC bool operator()(T a, T b) const { type ThreadSort (line 54) | struct ThreadSort { function METAL_FUNC (line 88) | static METAL_FUNC int merge_partition( function METAL_FUNC (line 114) | static METAL_FUNC void merge_step( function METAL_FUNC (line 146) | static METAL_FUNC void sort( function METAL_FUNC (line 258) | static METAL_FUNC void block_sort( function block_sort (line 307) | void block_sort( function block_sort_nc (line 362) | void block_sort_nc( function METAL_FUNC (line 434) | static METAL_FUNC void block_sort( function METAL_FUNC (line 472) | static METAL_FUNC int merge_partition( function mb_block_sort (line 505) | void mb_block_sort( function mb_block_partition (line 549) | void mb_block_partition( FILE: mlx/backend/metal/kernels/steel/attn/attn.h function namespace (line 18) | namespace mlx { FILE: mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h type MaxOp (line 18) | struct MaxOp { type SumOp (line 25) | struct SumOp { type MulOp (line 32) | struct MulOp { type SubOp (line 39) | struct SubOp { type ExpSubOp (line 46) | struct ExpSubOp { type DivOp (line 53) | struct DivOp { function ulong3 (line 88) | ulong3 tidl{tid.x, tid.y, tid.z}; FILE: mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h function METAL_FUNC (line 24) | METAL_FUNC TransformScale(T scale_) : scale(scale_) {} function METAL_FUNC (line 26) | METAL_FUNC T apply(T x) const { type MaxOp (line 31) | struct MaxOp { type SumOp (line 38) | struct SumOp { type MulOp (line 45) | struct MulOp { type SubOp (line 52) | struct SubOp { type ExpSubOp (line 59) | struct ExpSubOp { type DivOp (line 66) | struct DivOp { function ulong3 (line 102) | ulong3 tidl{tid.x, tid.y, tid.z}; FILE: mlx/backend/metal/kernels/steel/attn/loader.h function namespace (line 11) | namespace mlx { function METAL_FUNC (line 199) | METAL_FUNC void load_unsafe() const { function METAL_FUNC (line 210) | METAL_FUNC void load_safe(short2 src_tile_dim) const { function METAL_FUNC (line 258) | METAL_FUNC void next() { FILE: mlx/backend/metal/kernels/steel/attn/mma.h function namespace (line 19) | namespace mlx { function METAL_FUNC (line 513) | METAL_FUNC BlockMMA( function METAL_FUNC (line 533) | METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { function METAL_FUNC (line 560) | METAL_FUNC void store_result(device U* D, const int ldd) { function METAL_FUNC (line 573) | METAL_FUNC void function METAL_FUNC (line 675) | METAL_FUNC void store_result( function METAL_FUNC (line 707) | METAL_FUNC void store_result_safe( FILE: mlx/backend/metal/kernels/steel/attn/nax.h function namespace (line 20) | namespace mlx { function thread (line 356) | thread T* reduced_vals) { function thread (line 376) | thread T* row_vals) { type typename (line 560) | typedef typename NAXFrag_t::template dtype_frag_t frag_type; function METAL_FUNC (line 564) | METAL_FUNC NAXTile() thread {} function METAL_FUNC (line 566) | METAL_FUNC constexpr void clear() { function ta (line 844) | constexpr auto ta = metal::bool_constant{} function tb (line 845) | constexpr auto tb = metal::bool_constant{} FILE: mlx/backend/metal/kernels/steel/attn/params.h function namespace (line 9) | namespace mlx { FILE: mlx/backend/metal/kernels/steel/attn/transforms.h function namespace (line 11) | namespace mlx { FILE: mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h function implicit_gemm_conv_2d_general (line 16) | void FILE: mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h function namespace (line 13) | namespace mlx { function METAL_FUNC (line 290) | METAL_FUNC void load_unsafe() const { function METAL_FUNC (line 315) | METAL_FUNC void next() { function METAL_FUNC (line 438) | METAL_FUNC void next() { function METAL_FUNC (line 562) | METAL_FUNC void load_unsafe() const { function METAL_FUNC (line 591) | METAL_FUNC void next() { function METAL_FUNC (line 780) | METAL_FUNC void load_unsafe() const { function METAL_FUNC (line 807) | METAL_FUNC void next() { function METAL_FUNC (line 941) | METAL_FUNC void next() { FILE: mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h function namespace (line 13) | namespace mlx { function METAL_FUNC (line 190) | METAL_FUNC void next() { function METAL_FUNC (line 313) | METAL_FUNC void next() { FILE: mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h function namespace (line 11) | namespace mlx { function METAL_FUNC (line 311) | METAL_FUNC void load_safe(const short remaining_k) const { function METAL_FUNC (line 361) | METAL_FUNC void next() { FILE: mlx/backend/metal/kernels/steel/conv/params.h function MLXConvParams (line 23) | static MLXConvParams function namespace (line 48) | namespace mlx { FILE: mlx/backend/metal/kernels/steel/gemm/gemm.h function namespace (line 17) | namespace mlx { FILE: mlx/backend/metal/kernels/steel/gemm/gemm_nax.h function namespace (line 12) | namespace mlx::steel { FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h type _NoMask (line 11) | struct _NoMask { function METAL_FUNC (line 32) | METAL_FUNC OutT apply(InT x) const { type nomask_t (line 37) | typedef struct _NoMask nomask_t; function block_masked_gemm (line 52) | void FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h function gemm_splitk (line 21) | void gemm_splitk( function gemm_splitk_accum (line 172) | [[kernel]] void gemm_splitk_accum( function gemm_splitk_accum_axpby (line 199) | [[kernel]] void gemm_splitk_accum_axpby( FILE: mlx/backend/metal/kernels/steel/gemm/loader.h function namespace (line 11) | namespace mlx { FILE: mlx/backend/metal/kernels/steel/gemm/mma.h function namespace (line 19) | namespace mlx { function METAL_FUNC (line 181) | METAL_FUNC static constexpr void mma( function METAL_FUNC (line 200) | METAL_FUNC static constexpr void mma( type typename (line 230) | typedef typename MMAFrag_t::mat_type mat_type; type typename (line 231) | typedef typename MMAFrag_t::frag_type frag_type; function METAL_FUNC (line 235) | METAL_FUNC MMATile() thread {} function METAL_FUNC (line 237) | METAL_FUNC constexpr void clear() { function METAL_FUNC (line 254) | METAL_FUNC mat_type mat_at(const short i, const short j) { function elem_type (line 263) | elem_type* elems() { function elem_type (line 267) | elem_type* elems() const { function METAL_FUNC (line 426) | static METAL_FUNC complex64_t apply(complex64_t x) { function METAL_FUNC (line 429) | static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) { function METAL_FUNC (line 488) | METAL_FUNC BlockMMA( function METAL_FUNC (line 508) | METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { function METAL_FUNC (line 535) | METAL_FUNC void store_result(device U* D, const int ldd) { function METAL_FUNC (line 548) | METAL_FUNC void function METAL_FUNC (line 568) | METAL_FUNC void function METAL_FUNC (line 670) | METAL_FUNC void store_result( function METAL_FUNC (line 702) | METAL_FUNC void store_result_safe( function METAL_FUNC (line 820) | METAL_FUNC BlockMMA( function METAL_FUNC (line 840) | METAL_FUNC void mma( function METAL_FUNC (line 900) | METAL_FUNC void store_result(device U* D, const int ldd) { function METAL_FUNC (line 919) | METAL_FUNC void function METAL_FUNC (line 950) | METAL_FUNC void function METAL_FUNC (line 1068) | METAL_FUNC void store_result( function METAL_FUNC (line 1102) | METAL_FUNC void store_result_safe( FILE: mlx/backend/metal/kernels/steel/gemm/nax.h function namespace (line 20) | namespace mlx { function thread (line 356) | thread T* reduced_vals) { function thread (line 376) | thread T* row_vals) { type typename (line 560) | typedef typename NAXFrag_t::template dtype_frag_t frag_type; function METAL_FUNC (line 564) | METAL_FUNC NAXTile() thread {} function METAL_FUNC (line 566) | METAL_FUNC constexpr void clear() { function ta (line 844) | constexpr auto ta = metal::bool_constant{} function tb (line 845) | constexpr auto tb = metal::bool_constant{} FILE: mlx/backend/metal/kernels/steel/gemm/params.h function namespace (line 9) | namespace mlx { FILE: mlx/backend/metal/kernels/steel/gemm/transforms.h function namespace (line 11) | namespace mlx { FILE: mlx/backend/metal/kernels/steel/utils.h function METAL_FUNC (line 7) | METAL_FUNC ulong2 elem_to_loc_broadcast( function METAL_FUNC (line 24) | METAL_FUNC ulong3 elem_to_loc_broadcast( FILE: mlx/backend/metal/kernels/steel/utils/integral_constant.h function namespace (line 10) | namespace mlx { FILE: mlx/backend/metal/kernels/steel/utils/type_traits.h function namespace (line 9) | namespace metal { FILE: mlx/backend/metal/kernels/ternary.h function ternary_v (line 9) | void ternary_v( function ternary_v2 (line 38) | void ternary_v2( function ternary_g_nd1 (line 63) | void ternary_g_nd1( function ternary_g_nd2 (line 79) | void ternary_g_nd2( function ternary_g_nd3 (line 97) | void ternary_g_nd3( FILE: mlx/backend/metal/kernels/ternary_ops.h type Select (line 5) | struct Select { FILE: mlx/backend/metal/kernels/unary.h function unary_v (line 4) | void unary_v( function unary_v2 (line 22) | void unary_v2( FILE: mlx/backend/metal/kernels/unary_ops.h type Abs (line 17) | struct Abs { function complex64_t (line 37) | complex64_t operator()(complex64_t x) { type ArcCos (line 42) | struct ArcCos { type ArcCosh (line 51) | struct ArcCosh { type ArcSin (line 58) | struct ArcSin { type ArcSinh (line 67) | struct ArcSinh { type ArcTan (line 74) | struct ArcTan { type ArcTanh (line 83) | struct ArcTanh { type BitwiseInvert (line 90) | struct BitwiseInvert { type Ceil (line 97) | struct Ceil { type Cos (line 131) | struct Cos { function complex64_t (line 137) | complex64_t operator()(complex64_t x) { type Cosh (line 144) | struct Cosh { function complex64_t (line 150) | complex64_t operator()(complex64_t x) { type Conjugate (line 157) | struct Conjugate { type Erf (line 163) | struct Erf { type ErfInv (line 170) | struct ErfInv { type Exp (line 177) | struct Exp { function complex64_t (line 182) | complex64_t operator()(complex64_t x) { type Expm1 (line 187) | struct Expm1 { type Floor (line 194) | struct Floor { type Imag (line 228) | struct Imag { type Log (line 234) | struct Log { function complex64_t (line 240) | complex64_t operator()(complex64_t x) { type Log2 (line 247) | struct Log2 { function complex64_t (line 253) | complex64_t operator()(complex64_t x) { type Log10 (line 259) | struct Log10 { function complex64_t (line 265) | complex64_t operator()(complex64_t x) { type Log1p (line 271) | struct Log1p { type LogicalNot (line 278) | struct LogicalNot { type Negative (line 285) | struct Negative { type Real (line 292) | struct Real { type Round (line 298) | struct Round { function complex64_t (line 303) | complex64_t operator()(complex64_t x) { type Sigmoid (line 308) | struct Sigmoid { type Sign (line 316) | struct Sign { function complex64_t (line 324) | complex64_t operator()(complex64_t x) { type Sin (line 333) | struct Sin { function complex64_t (line 339) | complex64_t operator()(complex64_t x) { type Sinh (line 346) | struct Sinh { function complex64_t (line 352) | complex64_t operator()(complex64_t x) { type Square (line 359) | struct Square { type Sqrt (line 366) | struct Sqrt { function complex64_t (line 372) | complex64_t operator()(complex64_t x) { type Rsqrt (line 384) | struct Rsqrt { function complex64_t (line 390) | complex64_t operator()(complex64_t x) { type Tan (line 395) | struct Tan { function complex64_t (line 401) | complex64_t operator()(complex64_t x) { type Tanh (line 410) | struct Tanh { function complex64_t (line 416) | complex64_t operator()(complex64_t x) { function i (line 426) | auto i = complex64_t{0.0, 1.0}; function i (line 432) | auto i = complex64_t{0.0, 1.0}; function i (line 438) | auto i = complex64_t{0.0, 1.0}; type ToFP8 (line 443) | struct ToFP8 { type FromFP8 (line 450) | struct FromFP8 { FILE: mlx/backend/metal/kernels/utils.h type half (line 13) | typedef half float16_t; function bool (line 73) | struct Limits { function complex64_t (line 79) | struct Limits { function OffsetT (line 205) | OffsetT offset{0} function index (line 206) | int index{0} function next (line 210) | void next(const constant int* shape, const constant int64_t* strides) { function next (line 223) | void next(int n, const constant int* shape, const constant int64_t* stri... function OffsetT (line 246) | OffsetT location() { function OffsetT (line 254) | OffsetT offset{0}; function T (line 307) | T ceildiv(T N, U M) { function log1p (line 312) | inline float log1p(float x) { function bfloat16_t (line 324) | inline bfloat16_t log1p(bfloat16_t x) { function complex64_t (line 336) | inline complex64_t log1p(complex64_t in) { FILE: mlx/backend/metal/logsumexp.cpp type mlx::core (line 10) | namespace mlx::core { FILE: mlx/backend/metal/matmul.cpp type mlx::core (line 21) | namespace mlx::core { function check_transpose (line 25) | std::tuple check_transpose( function array (line 43) | inline array function ensure_batch_contiguous (line 54) | inline std::tuple function steel_matmul_regular_axpby_nax (line 176) | void steel_matmul_regular_axpby_nax( function steel_matmul_regular_axpby (line 341) | void steel_matmul_regular_axpby( function steel_gemm_splitk_axpby (line 530) | void steel_gemm_splitk_axpby( function steel_gemm_splitk_axpby_nax (line 687) | void steel_gemm_splitk_axpby_nax( function steel_matmul_axpby (line 859) | void steel_matmul_axpby( function gemv_axbpy (line 1033) | void gemv_axbpy( function gemv (line 1172) | inline void gemv( function gather_mm_rhs (line 1844) | void gather_mm_rhs( function gather_mm_rhs_nax (line 1977) | void gather_mm_rhs_nax( function gather_mv (line 2120) | void gather_mv( function gather_mm (line 2237) | void gather_mm( function segmented_mm (line 2424) | void segmented_mm( FILE: mlx/backend/metal/matmul.h function namespace (line 7) | namespace mlx::core { FILE: mlx/backend/metal/metal.cpp type mlx::core::metal (line 8) | namespace mlx::core::metal { function is_available (line 10) | bool is_available() { function start_capture (line 14) | void start_capture(std::string path, NS::Object* object) { function start_capture (line 39) | void start_capture(std::string path) { function stop_capture (line 44) | void stop_capture() { FILE: mlx/backend/metal/metal.h function namespace (line 11) | namespace mlx::core::metal { FILE: mlx/backend/metal/no_metal.cpp type mlx::core (line 8) | namespace mlx::core { type metal (line 10) | namespace metal { function is_available (line 12) | bool is_available() { function start_capture (line 16) | void start_capture(std::string) {} function stop_capture (line 17) | void stop_capture() {} type fast (line 27) | namespace fast { function CustomKernelFunction (line 29) | CustomKernelFunction metal_kernel( FILE: mlx/backend/metal/nojit_kernels.cpp type mlx::core (line 7) | namespace mlx::core { FILE: mlx/backend/metal/normalization.cpp type mlx::core::fast (line 11) | namespace mlx::core::fast { FILE: mlx/backend/metal/primitives.cpp type mlx::core (line 18) | namespace mlx::core { function arange_set_scalars (line 21) | void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { FILE: mlx/backend/metal/quantized.cpp type mlx::core (line 15) | namespace mlx::core { function get_quantized_kernel_wrapped (line 20) | auto get_quantized_kernel_wrapped( function get_qmm_nax_kernel_wrapped (line 37) | auto get_qmm_nax_kernel_wrapped( function array (line 53) | inline array function array (line 64) | inline array ensure_row_contiguous_matrix( function get_qmv_batch_limit (line 84) | inline int get_qmv_batch_limit(int D, int O, metal::Device& d) { function add_strides_and_shapes (line 128) | inline int add_strides_and_shapes( function add_gather_strides_and_shapes (line 158) | inline int add_gather_strides_and_shapes( function qmv_quad (line 177) | void qmv_quad( function qmv (line 235) | void qmv( function qvm_split_k (line 298) | void qvm_split_k( function qvm (line 418) | void qvm( function qmm_nax (line 472) | void qmm_nax( function gather_qmm_nax (line 575) | void gather_qmm_nax( function qmm (line 679) | void qmm( function gather_qmm (line 773) | void gather_qmm( function gather_qmv (line 864) | void gather_qmv( function gather_qvm (line 930) | void gather_qvm( function gather_qmm_rhs_nax (line 988) | void gather_qmm_rhs_nax( function gather_qmm_rhs (line 1119) | void gather_qmm_rhs( function dispatch_qmv (line 1269) | void dispatch_qmv( function quantize_dequantize (line 1463) | void quantize_dequantize( FILE: mlx/backend/metal/reduce.cpp type mlx::core (line 15) | namespace mlx::core { type RowReduceArgs (line 19) | struct RowReduceArgs { method RowReduceArgs (line 36) | RowReduceArgs( method encode (line 58) | void encode(CommandEncoder& compute_encoder) { type ColReduceArgs (line 89) | struct ColReduceArgs { method ColReduceArgs (line 107) | ColReduceArgs( method ColReduceArgs (line 144) | ColReduceArgs(const array& intermediate) { method encode (line 154) | void encode(CommandEncoder& compute_encoder) { function safe_div (line 188) | inline auto safe_div(size_t n, size_t m) { function safe_divup (line 192) | inline auto safe_divup(size_t n, size_t m) { function is_64b_int (line 196) | inline bool is_64b_int(Dtype dtype) { function is_64b_dtype (line 200) | inline bool is_64b_dtype(Dtype dtype) { function get_kernel_reduce_ndim (line 204) | inline int get_kernel_reduce_ndim(int reduce_ndim) { function threadgroup_size_from_row_size (line 214) | inline int threadgroup_size_from_row_size(int row_size) { function output_grid_for_col_reduce (line 233) | inline auto output_grid_for_col_reduce( function remap_reduce_types (line 245) | std::pair remap_reduce_types( function init_reduce (line 289) | void init_reduce( function all_reduce_dispatch (line 312) | void all_reduce_dispatch( function row_reduce_small (line 393) | void row_reduce_small( function row_reduce_simple (line 449) | void row_reduce_simple( function row_reduce_looped (line 489) | void row_reduce_looped( function row_reduce_general_dispatch (line 539) | void row_reduce_general_dispatch( function strided_reduce_small (line 566) | void strided_reduce_small( function strided_reduce_longcolumn (line 632) | void strided_reduce_longcolumn( function strided_reduce_looped (line 743) | void strided_reduce_looped( function strided_reduce_2pass (line 808) | void strided_reduce_2pass( function strided_reduce_general_dispatch (line 919) | void strided_reduce_general_dispatch( FILE: mlx/backend/metal/reduce.h function namespace (line 9) | namespace mlx::core { FILE: mlx/backend/metal/resident.cpp type mlx::core::metal (line 5) | namespace mlx::core::metal { FILE: mlx/backend/metal/resident.h function namespace (line 7) | namespace mlx::core::metal { FILE: mlx/backend/metal/rope.cpp type mlx::core::fast (line 6) | namespace mlx::core::fast { FILE: mlx/backend/metal/scaled_dot_product_attention.cpp type mlx::core::fast (line 14) | namespace mlx::core::fast { function sdpa_full_self_attention_nax (line 18) | void sdpa_full_self_attention_nax( function sdpa_full_self_attention_metal (line 166) | void sdpa_full_self_attention_metal( function sdpa_vector (line 329) | void sdpa_vector( function sdpa_vector_2pass (line 418) | void sdpa_vector_2pass( FILE: mlx/backend/metal/scan.cpp type mlx::core (line 13) | namespace mlx::core { function scan_gpu_inplace (line 15) | void scan_gpu_inplace( FILE: mlx/backend/metal/slicing.cpp type mlx::core (line 12) | namespace mlx::core { function concatenate_gpu (line 14) | void concatenate_gpu( function array (line 45) | array compute_dynamic_offset( FILE: mlx/backend/metal/softmax.cpp type mlx::core (line 11) | namespace mlx::core { FILE: mlx/backend/metal/sort.cpp type mlx::core (line 11) | namespace mlx::core { function single_block_sort (line 15) | void single_block_sort( function multi_block_sort (line 114) | void multi_block_sort( function gpu_merge_sort (line 274) | void gpu_merge_sort( FILE: mlx/backend/metal/ternary.cpp type mlx::core (line 9) | namespace mlx::core { function ternary_op_gpu_inplace (line 11) | void ternary_op_gpu_inplace( function ternary_op_gpu (line 135) | void ternary_op_gpu( function ternary_op_gpu (line 148) | void ternary_op_gpu( FILE: mlx/backend/metal/ternary.h function namespace (line 7) | namespace mlx::core { FILE: mlx/backend/metal/unary.cpp type mlx::core (line 14) | namespace mlx::core { function unary_op_gpu_inplace (line 16) | void unary_op_gpu_inplace( function unary_op_gpu (line 98) | void unary_op_gpu( function unary_op_gpu (line 107) | void unary_op_gpu( FILE: mlx/backend/metal/unary.h function namespace (line 7) | namespace mlx::core { FILE: mlx/backend/metal/utils.cpp type mlx::core (line 6) | namespace mlx::core { function type_to_name (line 8) | std::string type_to_name(const Dtype& t) { function type_to_name (line 57) | std::string type_to_name(const array& a) { function get_block_dims (line 61) | MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2) { function get_2d_grid_dims (line 66) | MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) { function get_2d_grid_dims (line 71) | MTL::Size FILE: mlx/backend/metal/utils.h function namespace (line 11) | namespace mlx::core { FILE: mlx/backend/no_cpu/compiled.cpp type mlx::core (line 6) | namespace mlx::core { type detail (line 11) | namespace detail { function compile_available_for_device (line 12) | bool compile_available_for_device(const Device& device) { FILE: mlx/backend/no_cpu/device_info.cpp type mlx::core::cpu (line 5) | namespace mlx::core::cpu { function is_available (line 7) | bool is_available() { function device_count (line 11) | int device_count() { FILE: mlx/backend/no_cpu/primitives.cpp type mlx::core (line 18) | namespace mlx::core { type fast (line 133) | namespace fast { type distributed (line 138) | namespace distributed { FILE: mlx/backend/no_gpu/allocator.cpp function get_memory_size (line 14) | size_t get_memory_size() { type mlx::core (line 19) | namespace mlx::core { type allocator (line 21) | namespace allocator { class CommonAllocator (line 23) | class CommonAllocator : public Allocator { method get_active_memory (line 30) | size_t get_active_memory() const { method get_peak_memory (line 33) | size_t get_peak_memory() const { method reset_peak_memory (line 36) | void reset_peak_memory() { method get_memory_limit (line 40) | size_t get_memory_limit() { method set_memory_limit (line 43) | size_t set_memory_limit(size_t limit) { method CommonAllocator (line 54) | CommonAllocator() : memory_limit_(0.8 * get_memory_size()) { function CommonAllocator (line 63) | CommonAllocator& common_allocator() { method get_active_memory (line 30) | size_t get_active_memory() const { method get_peak_memory (line 33) | size_t get_peak_memory() const { method reset_peak_memory (line 36) | void reset_peak_memory() { method get_memory_limit (line 40) | size_t get_memory_limit() { method set_memory_limit (line 43) | size_t set_memory_limit(size_t limit) { method CommonAllocator (line 54) | CommonAllocator() : memory_limit_(0.8 * get_memory_size()) { function Allocator (line 68) | Allocator& allocator() { function Buffer (line 79) | Buffer CommonAllocator::malloc(size_t size) { function get_active_memory (line 106) | size_t get_active_memory() { function get_peak_memory (line 109) | size_t get_peak_memory() { function reset_peak_memory (line 112) | void reset_peak_memory() { function set_memory_limit (line 115) | size_t set_memory_limit(size_t limit) { function get_memory_limit (line 118) | size_t get_memory_limit() { function get_cache_memory (line 123) | size_t get_cache_memory() { function set_cache_limit (line 126) | size_t set_cache_limit(size_t) { function set_wired_limit (line 129) | size_t set_wired_limit(size_t) { function clear_cache (line 132) | void clear_cache() {} FILE: mlx/backend/no_gpu/apple_memory.h function get_memory_size (line 9) | size_t get_memory_size() { FILE: mlx/backend/no_gpu/device_info.cpp type mlx::core::gpu (line 5) | namespace mlx::core::gpu { function is_available (line 7) | bool is_available() { function device_count (line 11) | int device_count() { FILE: mlx/backend/no_gpu/eval.cpp type mlx::core::gpu (line 8) | namespace mlx::core::gpu { function new_stream (line 10) | void new_stream(Stream) {} function eval (line 12) | void eval(array&) { function finalize (line 16) | void finalize(Stream) { function synchronize (line 20) | void synchronize(Stream) { FILE: mlx/backend/no_gpu/event.cpp type mlx::core (line 9) | namespace mlx::core { type EventCounter (line 11) | struct EventCounter { FILE: mlx/backend/no_gpu/fence.cpp type mlx::core (line 9) | namespace mlx::core { type FenceImpl (line 11) | struct FenceImpl { FILE: mlx/backend/no_gpu/linux_memory.h function get_memory_size (line 9) | size_t get_memory_size() { FILE: mlx/backend/no_gpu/primitives.cpp type mlx::core (line 24) | namespace mlx::core { type fast (line 164) | namespace fast { type distributed (line 177) | namespace distributed { FILE: mlx/compile.cpp type mlx::core (line 21) | namespace mlx::core { function is_unary (line 26) | bool is_unary(const Primitive& p) { function is_binary (line 47) | bool is_binary(const Primitive& p) { function is_ternary (line 61) | bool is_ternary(const Primitive& p) { function is_broadcast (line 65) | bool is_broadcast(const Primitive& p) { function is_noop (line 69) | bool is_noop(const Primitive& p) { function is_reduction (line 73) | bool is_reduction(const Primitive& p) { function is_fusable (line 77) | bool is_fusable(const Primitive& p) { type detail (line 214) | namespace detail { function merge_one (line 230) | void merge_one(array& dst, array& src, ParentsMap& parents_map) { function merge (line 256) | void merge(array& dst, array& src, ParentsMap& parents_map) { function array (line 268) | array split_one( function get_function_address (line 291) | std::uintptr_t get_function_address(const std::function& fu... class CompilerCache (line 300) | class CompilerCache { type CacheEntry (line 302) | struct CacheEntry { method CacheEntry (line 303) | CacheEntry(Stream stream, bool shapeless) method CacheEntry (line 317) | CacheEntry& find( method CacheEntry (line 303) | CacheEntry(Stream stream, bool shapeless) method erase (line 369) | void erase(std::uintptr_t fun_id) { method clear (line 373) | void clear() { method CompilerCache (line 378) | CompilerCache() { function CompilerCache (line 388) | CompilerCache& compiler_cache() { type CacheEntry (line 302) | struct CacheEntry { method CacheEntry (line 303) | CacheEntry(Stream stream, bool shapeless) method CacheEntry (line 317) | CacheEntry& find( method CacheEntry (line 303) | CacheEntry(Stream stream, bool shapeless) method erase (line 369) | void erase(std::uintptr_t fun_id) { method clear (line 373) | void clear() { method CompilerCache (line 378) | CompilerCache() { function compile_trace (line 393) | std::tuple, std::vector, std::shared_ptr, ParentsMap> compile_dfs( function splitmix64 (line 544) | static inline uint64_t splitmix64(uint64_t x) noexcept { type VecU64Hash (line 551) | struct VecU64Hash { function compile_simplify (line 564) | void compile_simplify( function compile_fuse (line 779) | void compile_fuse( function compile_replace (line 1021) | std::vector compile_replace( function skip_compile (line 1088) | bool skip_compile() { function ArrayFnWithExtra (line 1093) | ArrayFnWithExtra compile( function compile (line 1160) | std::function(const std::vector&)> compile( function compile_erase (line 1187) | void compile_erase(std::uintptr_t fun_id) { function compile_clear_cache (line 1191) | void compile_clear_cache() { function compile (line 1197) | std::function(const std::vector&)> compile( function compile (line 1226) | std::function(const std::vector&)> compile( function disable_compile (line 1235) | void disable_compile() { function enable_compile (line 1239) | void enable_compile() { function set_compile_mode (line 1243) | void set_compile_mode(CompileMode mode) { FILE: mlx/compile.h function namespace (line 8) | namespace mlx::core { FILE: mlx/compile_impl.h function namespace (line 10) | namespace mlx::core::detail { FILE: mlx/device.cpp type mlx::core (line 9) | namespace mlx::core { function Device (line 11) | Device& mutable_default_device() { function Device (line 16) | const Device& default_device() { function set_default_device (line 20) | void set_default_device(const Device& d) { function is_available (line 36) | bool is_available(const Device& d) { function device_count (line 47) | int device_count(Device::DeviceType type) { FILE: mlx/device.h function namespace (line 11) | namespace mlx::core { FILE: mlx/distributed/distributed.cpp type mlx::core::distributed (line 13) | namespace mlx::core::distributed { type detail (line 15) | namespace detail { function Stream (line 17) | Stream communication_stream(Group group, StreamOrDevice s /* = {} */) { function all_sum (line 21) | void all_sum(Group group, const array& input, array& output, Stream ... function all_max (line 25) | void all_max(Group group, const array& input, array& output, Stream ... function all_min (line 29) | void all_min(Group group, const array& input, array& output, Stream ... function all_gather (line 33) | void all_gather(Group group, const array& input, array& output, Stre... function send (line 37) | void send(Group group, const array& input, int dst, Stream stream) { function recv (line 41) | void recv(Group group, array& out, int src, Stream stream) { function sum_scatter (line 45) | void sum_scatter( class EmptyGroup (line 53) | class EmptyGroup : public GroupImpl { method Stream (line 55) | Stream communication_stream(StreamOrDevice s) override { method rank (line 59) | int rank() override { method size (line 63) | int size() override { method split (line 67) | std::shared_ptr split(int color, int key = -1) override { method all_sum (line 71) | void all_sum(const array&, array&, Stream) override { method all_gather (line 75) | void all_gather(const array&, array&, Stream) override { method send (line 79) | void send(const array&, int, Stream) override { method recv (line 83) | void recv(array&, int, Stream) override { method all_max (line 88) | void all_max(const array&, array&, Stream) override { method all_min (line 93) | void all_min(const array&, array&, Stream) override { method sum_scatter (line 97) | void sum_scatter(const array&, array&, Stream) override { function is_available (line 105) | bool is_available() { function is_available (line 110) | bool is_available(const std::string& bk) { function Group (line 137) | Group Group::split(int color, int key /* = -1 */) const { function Group (line 141) | Group init(bool strict /* = false */, const std::string& bk /* = "any"... FILE: mlx/distributed/distributed.h function namespace (line 11) | namespace mlx::core::distributed { FILE: mlx/distributed/distributed_impl.h function namespace (line 7) | namespace mlx::core::distributed::detail { FILE: mlx/distributed/jaccl/jaccl.cpp type DeviceFile (line 18) | struct DeviceFile { method DeviceFile (line 19) | DeviceFile(const char* dev_file) { method size (line 60) | int size() { method is_valid_mesh (line 64) | bool is_valid_mesh() { method is_valid_ring (line 76) | bool is_valid_ring() { method extract_mesh_connectivity (line 101) | std::vector extract_mesh_connectivity(int rank) { method extract_ring_connectivity (line 111) | std::pair, std::vector> type mlx::core::distributed::jaccl (line 124) | namespace mlx::core::distributed::jaccl { function is_available (line 126) | bool is_available() { function init (line 130) | std::shared_ptr init(bool strict /* = false */) { FILE: mlx/distributed/jaccl/jaccl.h function namespace (line 5) | namespace mlx::core::distributed::jaccl { FILE: mlx/distributed/jaccl/mesh.cpp type mlx::core::distributed::jaccl (line 8) | namespace mlx::core::distributed::jaccl { FILE: mlx/distributed/jaccl/mesh.h function namespace (line 12) | namespace mlx::core::distributed::jaccl { FILE: mlx/distributed/jaccl/mesh_impl.h function namespace (line 11) | namespace mlx::core::distributed::jaccl { function all_gather (line 134) | void all_gather(const char* in_ptr, char* out_ptr, int64_t n_bytes) { function send (line 214) | void send(const char* in_ptr, int64_t n_bytes, int dst) { function recv (line 264) | void recv(char* out_ptr, int64_t n_bytes, int src) { function recv_from (line 317) | void recv_from(int sz, int rank, int buff) { function post_send_all (line 330) | void post_send_all(int sz, int buff) { function post_recv_all (line 341) | void post_recv_all(int sz, int buff) { FILE: mlx/distributed/jaccl/no_jaccl.cpp type mlx::core::distributed::jaccl (line 5) | namespace mlx::core::distributed::jaccl { function is_available (line 9) | bool is_available() { function init (line 13) | std::shared_ptr init(bool strict /* = false */) { FILE: mlx/distributed/jaccl/ring.cpp type mlx::core::distributed::jaccl (line 8) | namespace mlx::core::distributed::jaccl { FILE: mlx/distributed/jaccl/ring.h function namespace (line 11) | namespace mlx::core::distributed::jaccl { FILE: mlx/distributed/jaccl/ring_impl.h function namespace (line 11) | namespace mlx::core::distributed::jaccl { function all_gather (line 299) | void function send (line 405) | void send(const char* in_ptr, int64_t n_bytes, int dst, int n_wires) { function recv (line 473) | void recv(char* out_ptr, int64_t n_bytes, int src, int n_wires) { function recv_from (line 551) | void recv_from(int sz, int buff, int left_right, int wire) { function post_recv_all (line 605) | void post_recv_all(int sz, int buff) { function post_send_all (line 618) | void post_send_all(int sz, int buff) { FILE: mlx/distributed/jaccl/utils.cpp type mlx::core::distributed::jaccl (line 34) | namespace mlx::core::distributed::jaccl { function IBVWrapper (line 69) | IBVWrapper& ibv() { function Destination (line 177) | const Destination& Connection::info() { function create_connections (line 257) | std::vector create_connections( FILE: mlx/distributed/jaccl/utils.h function namespace (line 48) | namespace mlx::core::distributed::jaccl { type Destination (line 90) | struct Destination { function class (line 100) | class SharedBuffer { function post_recv (line 198) | void post_recv(const SharedBuffer& buff, uint64_t work_request_id) { function poll (line 216) | int poll(int num_completions, ibv_wc* work_completions) { function poll (line 224) | inline int poll( function poll (line 247) | inline int poll( function class (line 268) | class SideChannel { FILE: mlx/distributed/mpi/mpi.cpp type mlx::core::distributed::mpi (line 35) | namespace mlx::core::distributed::mpi { function simple_sum (line 42) | void simple_sum( function simple_max (line 61) | void simple_max( function simple_min (line 81) | void simple_min( type MPIWrapper (line 100) | struct MPIWrapper { method MPIWrapper (line 101) | MPIWrapper() { method is_available (line 161) | bool is_available() { method init_safe (line 165) | bool init_safe() { method finalize_safe (line 195) | void finalize_safe() { method MPI_Comm (line 201) | MPI_Comm world() { method MPI_Datatype (line 205) | MPI_Datatype datatype(const array& arr) { method MPI_Op (line 240) | MPI_Op op_sum(const array& arr) { method MPI_Op (line 251) | MPI_Op op_max(const array& arr) { method MPI_Op (line 264) | MPI_Op op_min(const array& arr) { function MPIWrapper (line 339) | MPIWrapper& mpi() { method MPIWrapper (line 101) | MPIWrapper() { method is_available (line 161) | bool is_available() { method init_safe (line 165) | bool init_safe() { method finalize_safe (line 195) | void finalize_safe() { method MPI_Comm (line 201) | MPI_Comm world() { method MPI_Datatype (line 205) | MPI_Datatype datatype(const array& arr) { method MPI_Op (line 240) | MPI_Op op_sum(const array& arr) { method MPI_Op (line 251) | MPI_Op op_max(const array& arr) { method MPI_Op (line 264) | MPI_Op op_min(const array& arr) { class MPIGroup (line 346) | class MPIGroup : public GroupImpl { method MPIGroup (line 348) | MPIGroup(MPI_Comm comm, bool global) method Stream (line 359) | Stream communication_stream(StreamOrDevice s) override { method rank (line 363) | int rank() override { method size (line 370) | int size() override { method split (line 377) | std::shared_ptr split(int color, int key = -1) override { method all_sum (line 389) | void all_sum(const array& input, array& output, Stream stream) overr... method all_max (line 404) | void all_max(const array& input, array& output, Stream stream) overr... method all_min (line 419) | void all_min(const array& input, array& output, Stream stream) overr... method all_gather (line 434) | void all_gather(const array& input, array& output, Stream stream) ov... method send (line 449) | void send(const array& input, int dst, Stream stream) override { method recv (line 462) | void recv(array& out, int src, Stream stream) override { method sum_scatter (line 475) | void sum_scatter(const array& input, array& output, Stream stream) o... function is_available (line 486) | bool is_available() { function init (line 490) | std::shared_ptr init(bool strict /* = false */) { FILE: mlx/distributed/mpi/mpi.h function namespace (line 5) | namespace mlx::core::distributed::mpi { FILE: mlx/distributed/mpi/mpi_declarations.h type MPI_Status (line 22) | typedef struct ompi_status_public_t { FILE: mlx/distributed/mpi/no_mpi.cpp type mlx::core::distributed::mpi (line 5) | namespace mlx::core::distributed::mpi { function is_available (line 9) | bool is_available() { function init (line 13) | std::shared_ptr init(bool strict /* = false */) { FILE: mlx/distributed/nccl/nccl.cpp type mlx::core::distributed::nccl (line 27) | namespace mlx::core::distributed::nccl { type nccl_map (line 73) | struct nccl_map { type detail (line 87) | namespace detail { function dispatch_dtype (line 90) | void dispatch_dtype(const array& arr, F&& f) { function sendAll (line 102) | inline void sendAll(int sock, const void* buf, size_t len) { function recvAll (line 115) | inline void recvAll(int sock, void* buf, size_t len) { function bootstrap_unique_id (line 130) | inline void bootstrap_unique_id( function bootstrap_unique_id (line 258) | inline void bootstrap_unique_id( function get_env_var_or_throw (line 439) | std::string get_env_var_or_throw(const char* env_var_name, bool stri... type NCCLComm (line 271) | struct NCCLComm { method NCCLComm (line 276) | NCCLComm(ncclComm_t c, int rank, int size) method create (line 279) | static std::shared_ptr method split (line 286) | static std::shared_ptr split(NCCLComm* source, int color, ... method NCCLComm (line 297) | NCCLComm(const NCCLComm&) = delete; method NCCLComm (line 298) | NCCLComm& operator=(const NCCLComm&) = delete; class NCCLGroup (line 302) | class NCCLGroup : public GroupImpl { method NCCLGroup (line 304) | NCCLGroup(int worldRank, int worldSize, const std::string initMethod) method NCCLGroup (line 316) | NCCLGroup(std::shared_ptr comm, int rank, int size) method Stream (line 319) | Stream communication_stream(StreamOrDevice s) override { method rank (line 323) | int rank() override { method size (line 327) | int size() override { method all_sum (line 331) | void all_sum(const array& input, array& output, Stream stream) overr... method split (line 338) | std::shared_ptr split(int color, int key = -1) override { method all_gather (line 345) | void all_gather(const array& input, array& output, Stream stream) ov... method send (line 359) | void send(const array& input, int dst, Stream stream) override { method recv (line 363) | void recv(array& output, int src, Stream stream) override { method all_max (line 367) | void all_max(const array& input, array& output, Stream stream) overr... method all_min (line 374) | void all_min(const array& input, array& output, Stream stream) overr... method sum_scatter (line 381) | void sum_scatter(const array& input, array& output, Stream stream) o... method all_reduce_impl (line 389) | void all_reduce_impl( method reduce_scatter_impl (line 408) | void reduce_scatter_impl( function is_available (line 434) | bool is_available() { type detail (line 438) | namespace detail { function dispatch_dtype (line 90) | void dispatch_dtype(const array& arr, F&& f) { function sendAll (line 102) | inline void sendAll(int sock, const void* buf, size_t len) { function recvAll (line 115) | inline void recvAll(int sock, void* buf, size_t len) { function bootstrap_unique_id (line 130) | inline void bootstrap_unique_id( function bootstrap_unique_id (line 258) | inline void bootstrap_unique_id( function get_env_var_or_throw (line 439) | std::string get_env_var_or_throw(const char* env_var_name, bool stri... function init (line 455) | std::shared_ptr init(bool strict /* = false */) { FILE: mlx/distributed/nccl/nccl.h function namespace (line 5) | namespace mlx::core::distributed::nccl { FILE: mlx/distributed/nccl/no_nccl.cpp type mlx::core::distributed::nccl (line 5) | namespace mlx::core::distributed::nccl { function is_available (line 9) | bool is_available() { function init (line 13) | std::shared_ptr init(bool strict /* = false */) { FILE: mlx/distributed/ops.cpp type mlx::core::distributed (line 11) | namespace mlx::core::distributed { function Group (line 15) | Group to_group(std::optional group) { function array (line 25) | array all_sum( function array (line 43) | array all_max( function array (line 61) | array all_min( function array (line 79) | array all_gather( function array (line 103) | array send( function array (line 126) | array recv( function array (line 152) | array recv_like( function array (line 160) | array sum_scatter( FILE: mlx/distributed/ops.h function namespace (line 11) | namespace mlx::core::distributed { FILE: mlx/distributed/primitives.cpp type mlx::core::distributed (line 10) | namespace mlx::core::distributed { FILE: mlx/distributed/primitives.h function namespace (line 9) | namespace mlx::core::distributed { function class (line 24) | class AllReduce : public DistPrimitive { function class (line 70) | class AllGather : public DistPrimitive { function class (line 95) | class Send : public DistPrimitive { function class (line 114) | class Recv : public DistPrimitive { function class (line 130) | class ReduceScatter : public DistPrimitive { FILE: mlx/distributed/reduction_ops.h function namespace (line 3) | namespace mlx::core::distributed::detail { FILE: mlx/distributed/ring/no_ring.cpp type mlx::core::distributed::ring (line 5) | namespace mlx::core::distributed::ring { function is_available (line 9) | bool is_available() { function init (line 13) | std::shared_ptr init(bool strict /* = false */) { FILE: mlx/distributed/ring/ring.cpp type mlx::core::distributed::ring (line 90) | namespace mlx::core::distributed::ring { function log (line 105) | void log(std::ostream& os, T first) { function log (line 110) | void log(std::ostream& os, T first, Args... args) { function log_info (line 115) | void log_info(bool verbose, Args... args) { function ceildiv (line 124) | decltype(T() * U()) ceildiv(T a, U b) { class SocketThread (line 128) | class SocketThread { method SocketThread (line 130) | SocketThread(int fd) : fd_(fd), stop_(false) { method send (line 144) | std::future send(const T* buffer, size_t size) { method recv (line 149) | std::future recv(T* buffer, size_t size) { type SocketTask (line 154) | struct SocketTask { method SocketTask (line 155) | SocketTask(void* b, size_t s, std::promise&& p) method SocketTask (line 157) | SocketTask(SocketTask&& t) method send_impl (line 164) | std::future send_impl(const char* buffer, size_t size) { method recv_impl (line 181) | std::future recv_impl(char* buffer, size_t size) { method have_tasks (line 198) | bool have_tasks() { method worker (line 202) | void worker() { class CommunicationThreads (line 277) | class CommunicationThreads { method add (line 279) | void add(const std::vector& sockets) { method send (line 286) | std::future send(int socket, T* buffer, size_t size) { method recv (line 291) | std::future recv(int socket, T* buffer, size_t size) { function load_nodes (line 311) | std::vector> load_nodes(const char* hos... function accept_connections (line 331) | std::vector accept_connections( function make_connections (line 349) | std::vector make_connections( class RingGroup (line 381) | class RingGroup : public GroupImpl { method RingGroup (line 383) | RingGroup( method Stream (line 465) | Stream communication_stream(StreamOrDevice s) override { method rank (line 469) | int rank() override { method size (line 473) | int size() override { method all_sum (line 477) | void all_sum(const array& input, array& output, Stream stream) overr... method all_max (line 482) | void all_max(const array& input, array& output, Stream stream) overr... method all_min (line 487) | void all_min(const array& input, array& output, Stream stream) overr... method split (line 492) | std::shared_ptr split(int color, int key = -1) override { method all_gather (line 496) | void all_gather(const array& input, array& output, Stream stream) ov... method send (line 533) | void send(const array& input, int dst, Stream stream) override { method recv (line 554) | void recv(array& out, int src, Stream stream) override { method sum_scatter (line 578) | void sum_scatter(const array& input, array& output, Stream stream) o... method all_reduce (line 584) | void all_reduce( method all_reduce_impl (line 659) | void all_reduce_impl( method all_gather_impl (line 756) | void all_gather_impl( method send (line 792) | void method recv (line 811) | void recv(const std::vector& sockets, char* data, size_t data_s... function is_available (line 843) | bool is_available() { function init (line 847) | std::shared_ptr init(bool strict /* = false */) { FILE: mlx/distributed/ring/ring.h function namespace (line 5) | namespace mlx::core::distributed::ring { FILE: mlx/distributed/utils.cpp type mlx::core::distributed::detail (line 11) | namespace mlx::core::distributed::detail { function address_t (line 16) | address_t parse_address(const std::string& ip, const std::string& port) { function address_t (line 40) | address_t parse_address(const std::string& ip_port) { function TCPSocket (line 67) | TCPSocket& TCPSocket::operator=(TCPSocket&& s) { function TCPSocket (line 126) | TCPSocket TCPSocket::accept(const char* tag) { function TCPSocket (line 163) | TCPSocket TCPSocket::connect( FILE: mlx/distributed/utils.h function namespace (line 9) | namespace mlx::core::distributed::detail { FILE: mlx/dtype.cpp type mlx::core (line 7) | namespace mlx::core { function Dtype (line 86) | Dtype promote_types(const Dtype& t1, const Dtype& t2) { function kindof (line 91) | Dtype::Kind kindof(const Dtype& t) { class MLX_API (line 95) | class MLX_API class MLX_API (line 96) | class MLX_API class MLX_API (line 97) | class MLX_API class MLX_API (line 98) | class MLX_API class MLX_API (line 99) | class MLX_API class MLX_API (line 100) | class MLX_API class MLX_API (line 101) | class MLX_API class MLX_API (line 102) | class MLX_API class MLX_API (line 103) | class MLX_API class MLX_API (line 104) | class MLX_API class MLX_API (line 105) | class MLX_API class MLX_API (line 106) | class MLX_API class MLX_API (line 107) | class MLX_API class MLX_API (line 108) | class MLX_API function issubdtype (line 180) | bool issubdtype(const Dtype& a, const Dtype& b) { function issubdtype (line 184) | bool issubdtype(const Dtype::Category& cat, const Dtype& type) { function issubdtype (line 188) | bool issubdtype(const Dtype& type, const Dtype::Category& cat) { function issubdtype (line 192) | bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) { FILE: mlx/dtype.h type Dtype (line 14) | struct Dtype { function Dtype (line 85) | inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64... FILE: mlx/dtype_utils.cpp type mlx::core (line 5) | namespace mlx::core { FILE: mlx/dtype_utils.h function namespace (line 10) | namespace mlx::core { FILE: mlx/einsum.cpp type mlx::core (line 10) | namespace mlx::core { type Subscript (line 24) | struct Subscript { method Subscript (line 25) | Subscript(std::string str, CharSet set) type PathInfo (line 31) | struct PathInfo { type PathNode (line 39) | struct PathNode { method PathNode (line 40) | PathNode( function parse (line 60) | std::pair, std::string> parse(std::string sub... function disjoint (line 107) | bool disjoint(const CharSet& x, const CharSet& y) { function term_size (line 117) | size_t term_size(const T& term, std::unordered_map di... function flop_count (line 125) | size_t flop_count( function compute_cost_and_scaling (line 141) | std::pair compute_cost_and_scaling( function greedy_path (line 161) | std::tuple, size_t, int> greedy_path( function can_dot (line 335) | bool can_dot(const std::vector& inputs, const Subscript& ou... function array (line 349) | array batch_tensordot( function array (line 424) | array collapse_repeats(array in, Subscript& subscript, StreamOrDevice ... function preprocess_einsum_inputs (line 490) | void preprocess_einsum_inputs( function array (line 537) | array einsum_naive( function einsum_path_helper (line 630) | std::pair, PathInfo> einsum_path_helper( function einsum_path (line 829) | std::pair>, std::string> einsum_path( function array (line 851) | array einsum( FILE: mlx/einsum.h function namespace (line 12) | namespace mlx::core { FILE: mlx/event.h function namespace (line 10) | namespace mlx::core { FILE: mlx/export.cpp function is_big_endian (line 23) | bool is_big_endian() { type mlx::core (line 28) | namespace mlx::core { type PrimitiveSerializer (line 35) | struct PrimitiveSerializer { method PrimitiveSerializer (line 41) | PrimitiveSerializer( type NotSerializable (line 90) | struct NotSerializable { type NotDeserializable (line 95) | struct NotDeserializable { function reverse_bytes (line 100) | void reverse_bytes(T& data) { function serialize (line 114) | void serialize(Writer& os, T v) { function T (line 146) | T deserialize(Reader& is) { type VariantType (line 183) | enum class VariantType { Int = 0, Float = 1, Bool = 2 } function serialize_variant (line 186) | void serialize_variant(Writer& os, T v) { function T (line 206) | T deserialize_variant(Reader& is) { function deserialize_tuple (line 222) | decltype(auto) deserialize_tuple(Reader& is, std::index_sequence) { function serialize (line 226) | void serialize(Writer& os, const Stream& s) { function Stream (line 232) | Stream deserialize(Reader& is) { function serialize (line 239) | void serialize(Writer& os, const Dtype& t) { function Dtype (line 245) | Dtype deserialize(Reader& is) { function serialize (line 251) | void serialize(Writer& os, const array& arr) { function array (line 256) | array deserialize(Reader& is) { function serialize_primitive (line 270) | void serialize_primitive(Writer& os, const Primitive& p) { function extract_state (line 277) | void extract_state(const T state, std::vector& unpacked_state) { function primitive_state (line 296) | std::vector primitive_state(const Primitive& p) { function deserialize_primitive (line 305) | std::shared_ptr deserialize_primitive(Reader& is, Stream s) { type PrimitiveFactory (line 321) | struct PrimitiveFactory { method PrimitiveFactory (line 453) | PrimitiveFactory() { method save (line 461) | void save(Writer& os, const std::shared_ptr& p) { method Stream (line 477) | Stream resolve_stream(const Stream& stream) { method load (line 494) | std::shared_ptr load(Reader& is) { method extract_state (line 505) | std::pair> extract_state( function write_header (line 523) | void write_header(Writer& os, int count, bool shapeless) { type FunctionTable (line 530) | struct FunctionTable { method FunctionTable (line 531) | FunctionTable(bool shapeless = false) : shapeless(shapeless) {} type Function (line 532) | struct Function { method Function (line 533) | Function( method Function (line 547) | Function(const Function&) = delete; method Function (line 548) | Function& operator=(const Function&) = delete; method Function (line 549) | Function(Function&&) = default; method Function (line 550) | Function() = default; method insert (line 558) | void insert( method print_functions (line 571) | void print_functions(std::ostream& os) { function FunctionExporter (line 891) | FunctionExporter exporter( function FunctionExporter (line 901) | FunctionExporter exporter( function FunctionExporter (line 911) | FunctionExporter exporter( function export_function (line 918) | void export_function( function export_function (line 926) | void export_function( function export_function (line 934) | void export_function( function FunctionExporter (line 943) | FunctionExporter exporter( function FunctionExporter (line 953) | FunctionExporter exporter( function FunctionExporter (line 963) | FunctionExporter exporter( function export_function (line 970) | void export_function( function export_function (line 978) | void export_function( function export_function (line 986) | void export_function( function ImportedFunction (line 1034) | ImportedFunction import_function(const std::string& file) { FILE: mlx/export.h function namespace (line 12) | namespace mlx::core { FILE: mlx/export_impl.h function namespace (line 8) | namespace mlx::core { FILE: mlx/fast.cpp type mlx::core::fast (line 11) | namespace mlx::core::fast { function array (line 53) | array rms_norm( function array (line 190) | array layer_norm( function array (line 367) | array rope( function array (line 531) | array rope( function array (line 560) | array rope( function array (line 613) | array scaled_dot_product_attention( FILE: mlx/fast.h function namespace (line 11) | namespace mlx::core::fast { FILE: mlx/fast_primitives.h function class (line 13) | class Custom : public Primitive { function eval_cpu (line 49) | void eval_cpu(const std::vector& inputs, std::vector& outp... function DEFINE_INPUT_OUTPUT_SHAPE (line 62) | DEFINE_NAME(RMSNorm) function DEFINE_NAME (line 86) | void eval_gpu(const std::vector& inputs, std::vector& outp... FILE: mlx/fence.h function namespace (line 7) | namespace mlx::core { FILE: mlx/fft.cpp type mlx::core::fft (line 10) | namespace mlx::core::fft { function array (line 12) | array fft_impl( function array (line 101) | array fft_impl( function array (line 117) | array fft_impl(const array& a, bool real, bool inverse, StreamOrDevice... function array (line 123) | array fftn( function array (line 130) | array fftn( function array (line 136) | array fftn(const array& a, StreamOrDevice s /* = {} */) { function array (line 140) | array ifftn( function array (line 147) | array ifftn( function array (line 153) | array ifftn(const array& a, StreamOrDevice s /* = {} */) { function array (line 157) | array rfftn( function array (line 164) | array rfftn( function array (line 170) | array rfftn(const array& a, StreamOrDevice s /* = {} */) { function array (line 174) | array irfftn( function array (line 181) | array irfftn( function array (line 188) | array irfftn(const array& a, StreamOrDevice s /* = {} */) { function array (line 192) | array fftshift( function array (line 217) | array ifftshift( function array (line 244) | array fftshift(const array& a, StreamOrDevice s /* = {} */) { function array (line 253) | array ifftshift(const array& a, StreamOrDevice s /* = {} */) { FILE: mlx/fft.h function namespace (line 12) | namespace mlx::core::fft { FILE: mlx/graph_utils.cpp type mlx::core (line 13) | namespace mlx::core { function depth_first_traversal (line 37) | void depth_first_traversal( function print_graph (line 62) | void print_graph( function export_to_dot (line 105) | void export_to_dot( FILE: mlx/graph_utils.h function NodeNamer (line 12) | struct MLX_API NodeNamer { function print_graph (line 24) | inline void print_graph(std::ostream& os, const std::vector& outp... function export_to_dot (line 48) | inline void export_to_dot(std::ostream& os, const std::vector& ou... FILE: mlx/io.h function namespace (line 14) | namespace mlx::core { FILE: mlx/io/gguf.cpp type mlx::core (line 11) | namespace mlx::core { function dtype_to_gguf_tensor_type (line 16) | std::optional dtype_to_gguf_tensor_type(const Dtype& dtype) { function gguf_type_to_dtype (line 33) | std::optional gguf_type_to_dtype(const uint32_t& gguf_type) { function Shape (line 50) | Shape get_shape(const gguf_tensor& tensor) { function extract_tensor_data (line 59) | std::tuple extract_tensor_data(gguf_tensor* ... function set_mx_value_from_gguf (line 90) | void set_mx_value_from_gguf( function load_metadata (line 203) | std::unordered_map load_metadata(gguf_ctx* ... function load_arrays (line 214) | std::unordered_map load_arrays(gguf_ctx* ctx) { function GGUFLoad (line 241) | GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { function append_kv_array (line 261) | void append_kv_array( function save_gguf (line 294) | void save_gguf( FILE: mlx/io/gguf.h function namespace (line 13) | namespace mlx::core { FILE: mlx/io/gguf_quants.cpp type mlx::core (line 9) | namespace mlx::core { function unpack_32_4 (line 11) | void unpack_32_4(uint8_t* data, int8_t* dst) { function extract_q4_0_data (line 32) | void extract_q4_0_data( function extract_q4_1_data (line 53) | void extract_q4_1_data( function extract_q8_0_data (line 75) | void extract_q8_0_data( function gguf_load_quantized (line 100) | void gguf_load_quantized( FILE: mlx/io/load.cpp type mlx::core (line 23) | namespace mlx::core { function is_big_endian (line 36) | inline bool is_big_endian() { function dtype_to_array_protocol (line 47) | std::string dtype_to_array_protocol(const Dtype& t) { function Dtype (line 59) | Dtype dtype_from_array_protocol(std::string_view t) { function pread (line 120) | int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) { function save (line 144) | void save(std::shared_ptr out_stream, array a) { function save (line 219) | void save(std::string file, array a) { function array (line 229) | array load(std::shared_ptr in_stream, StreamOrDevice s) { function array (line 331) | array load(std::string file, StreamOrDevice s) { type io (line 335) | namespace io { function ThreadPool (line 337) | ThreadPool& thread_pool() { function ThreadPool (line 342) | ThreadPool& ParallelFileReader::thread_pool() { FILE: mlx/io/load.h function class (line 32) | class Reader { FILE: mlx/io/no_gguf.cpp type mlx::core (line 5) | namespace mlx::core { function GGUFLoad (line 7) | GGUFLoad load_gguf(const std::string&, StreamOrDevice s) { function save_gguf (line 12) | void save_gguf( FILE: mlx/io/no_safetensors.cpp type mlx::core (line 5) | namespace mlx::core { function SafetensorsLoad (line 7) | SafetensorsLoad load_safetensors(std::shared_ptr, StreamOr... function SafetensorsLoad (line 13) | SafetensorsLoad load_safetensors(const std::string&, StreamOrDevice) { function save_safetensors (line 19) | void save_safetensors( function save_safetensors (line 28) | void save_safetensors( FILE: mlx/io/safetensors.cpp type mlx::core (line 35) | namespace mlx::core { function dtype_to_safetensor_str (line 37) | std::string dtype_to_safetensor_str(Dtype t) { function Dtype (line 70) | Dtype dtype_from_safetensor_str(std::string_view str) { function SafetensorsLoad (line 106) | SafetensorsLoad load_safetensors( function SafetensorsLoad (line 162) | SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevi... function save_safetensors (line 166) | void save_safetensors( function save_safetensors (line 220) | void save_safetensors( FILE: mlx/linalg.cpp type mlx::core::linalg (line 11) | namespace mlx::core::linalg { function check_cpu_stream (line 13) | void check_cpu_stream(const StreamOrDevice& s, const std::string& pref... function check_float (line 21) | void check_float(Dtype dtype, const std::string& prefix) { function check_float_or_complex (line 30) | void check_float_or_complex(Dtype dtype, const std::string& prefix) { function Dtype (line 39) | Dtype at_least_float(const Dtype& d) { function array (line 43) | inline array l2_norm( function array (line 55) | inline array vector_norm( function array (line 80) | inline array matrix_norm( function array (line 136) | inline array matrix_norm( function array (line 165) | array norm( function array (line 181) | array norm( function array (line 204) | array norm( function qr (line 226) | std::pair qr(const array& a, StreamOrDevice s /* = {} */) { function svd (line 250) | std::vector function array (line 296) | array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) { function array (line 319) | array inv(const array& a, StreamOrDevice s /* = {} */) { function array (line 323) | array tri_inv( function array (line 330) | array cholesky( function array (line 356) | array pinv(const array& a, StreamOrDevice s /* = {} */) { function array (line 401) | array cholesky_inv( function array (line 430) | array cross( function validate_eig (line 502) | void validate_eig( function array (line 521) | array eigvalsh( function eigh (line 535) | std::pair eigh( function array (line 549) | array eigvals(const array& a, StreamOrDevice s /* = {} */) { function eig (line 559) | std::pair eig(const array& a, StreamOrDevice s /* = {} *... function validate_lu (line 569) | void validate_lu( function lu_helper (line 586) | std::vector lu_helper(const array& a, StreamOrDevice s /* = {} ... function lu (line 602) | std::vector lu(const array& a, StreamOrDevice s /* = {} */) { function lu_factor (line 629) | std::pair lu_factor(const array& a, StreamOrDevice s /* ... function validate_solve (line 635) | void validate_solve( function array (line 682) | array solve(const array& a, const array& b, StreamOrDevice s /* = {} *... function array (line 698) | array solve_triangular( FILE: mlx/linalg.h function namespace (line 13) | namespace mlx::core::linalg { FILE: mlx/memory.h function namespace (line 9) | namespace mlx::core { FILE: mlx/ops.cpp type mlx::core (line 21) | namespace mlx::core { function compute_reduce_shape (line 25) | std::tuple, bool> compute_reduce_shape( function Dtype (line 57) | Dtype at_least_float(const Dtype& d) { function array (line 61) | array indices_or_default( function validate_quantized_input (line 75) | void validate_quantized_input( function extract_quantized_matmul_dims (line 117) | std::pair extract_quantized_matmul_dims( function array (line 150) | array arange( function array (line 189) | array arange( function array (line 196) | array arange( function array (line 203) | array arange(double start, double stop, StreamOrDevice s /* = {} */) { function array (line 206) | array arange(double stop, Dtype dtype, StreamOrDevice s /* = {} */) { function array (line 209) | array arange(double stop, StreamOrDevice s /* = {} */) { function array (line 212) | array arange(int start, int stop, int step, StreamOrDevice s /* = {} *... function array (line 220) | array arange(int start, int stop, StreamOrDevice s /* = {} */) { function array (line 228) | array arange(int stop, StreamOrDevice s /* = {} */) { function array (line 232) | array linspace( function array (line 258) | array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) { function array (line 270) | array as_strided( function array (line 287) | array copy(array a, StreamOrDevice s /* = {} */) { function array (line 297) | array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) { function array (line 305) | array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* =... function array (line 312) | array full(Shape shape, array vals, StreamOrDevice s /* = {} */) { function array (line 317) | array full_like( function array (line 326) | array full_like(const array& a, array vals, StreamOrDevice s /* = {} *... function array (line 330) | array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} ... function array (line 334) | array zeros_like(const array& a, StreamOrDevice s /* = {} */) { function array (line 338) | array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} *... function array (line 342) | array ones_like(const array& a, StreamOrDevice s /* = {} */) { function array (line 346) | array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} *... function array (line 366) | array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) { function array (line 370) | array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) { function array (line 376) | array tril(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) { function array (line 384) | array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) { function array (line 392) | array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) { function array (line 404) | array unflatten( function array (line 457) | array flatten( function array (line 496) | array flatten(const array& a, StreamOrDevice s /* = {} */) { function array (line 500) | array hadamard_transform( function array (line 529) | array squeeze_impl( function array (line 557) | array squeeze( function array (line 575) | array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) { function array (line 579) | array squeeze(const array& a, StreamOrDevice s /* = {} */) { function array (line 589) | array expand_dims_impl( function array (line 612) | array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} *... function array (line 616) | array expand_dims( function normalize_slice (line 645) | inline auto function normalize_dynamic_slice_inputs (line 701) | void normalize_dynamic_slice_inputs( function array (line 751) | array slice( function array (line 780) | array slice( function array (line 789) | array slice( function array (line 822) | array slice_update( function array (line 862) | array slice_update( function array (line 873) | array slice_update( function array (line 902) | array slice_update( function array (line 950) | array slice_update_add( function array (line 967) | array slice_update_add( function array (line 977) | array slice_update_prod( function array (line 994) | array slice_update_prod( function array (line 1004) | array slice_update_max( function array (line 1021) | array slice_update_max( function array (line 1031) | array slice_update_min( function array (line 1048) | array slice_update_min( function split (line 1058) | std::vector split( function split (line 1104) | std::vector function split (line 1109) | std::vector function split (line 1140) | std::vector function meshgrid (line 1145) | std::vector meshgrid( function array (line 1180) | array clip( function array (line 1198) | array concatenate( function array (line 1259) | array concatenate(std::vector arrays, StreamOrDevice s /* = {} ... function array (line 1267) | array stack( function array (line 1289) | array stack(const std::vector& arrays, StreamOrDevice s /* = {}... function array (line 1294) | array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) { function array (line 1324) | array repeat(const array& arr, int repeats, StreamOrDevice s) { function array (line 1328) | array tile( function array (line 1358) | array edge_pad( function array (line 1404) | array pad( function array (line 1458) | array pad( function array (line 1478) | array pad( function array (line 1492) | array pad( function array (line 1506) | array moveaxis( function array (line 1533) | array swapaxes( function array (line 1556) | array transpose( function array (line 1595) | array transpose(const array& a, StreamOrDevice s /* = {} */) { function array (line 1601) | array broadcast_to( function broadcast_arrays (line 1630) | std::vector broadcast_arrays( function broadcast_arrays (line 1700) | std::vector broadcast_arrays( function broadcast_arrays (line 1751) | std::pair function broadcast_arrays (line 1757) | std::pair broadcast_arrays( function array (line 1766) | array equal(const array& a, const array& b, StreamOrDevice s /* = {} *... function array (line 1774) | array not_equal(const array& a, const array& b, StreamOrDevice s /* = ... function array (line 1785) | array greater(const array& a, const array& b, StreamOrDevice s /* = {}... function array (line 1793) | array greater_equal( function array (line 1807) | array less(const array& a, const array& b, StreamOrDevice s /* = {} */) { function array (line 1815) | array less_equal(const array& a, const array& b, StreamOrDevice s /* =... function array (line 1826) | array array_equal( function array (line 1847) | array isnan(const array& a, StreamOrDevice s /* = {} */) { function array (line 1854) | array isinf(const array& a, StreamOrDevice s /* = {} */) { function array (line 1861) | array isfinite(const array& a, StreamOrDevice s /* = {} */) { function array (line 1868) | array isposinf(const array& a, StreamOrDevice s /* = {} */) { function array (line 1875) | array isneginf(const array& a, StreamOrDevice s /* = {} */) { function array (line 1882) | array where( function array (line 1899) | array nan_to_num( function array (line 1933) | array allclose( function array (line 1943) | array isclose( function array (line 1984) | array all(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { function array (line 1990) | array all( function array (line 2010) | array all( function array (line 2018) | array any(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { function array (line 2024) | array any( function array (line 2044) | array any( function array (line 2052) | array sum(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { function array (line 2058) | array sum( function array (line 2089) | array sum( function array (line 2097) | array mean(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { function array (line 2103) | array mean( function array (line 2122) | array mean( function array (line 2130) | array median(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { function array (line 2136) | array median( function array (line 2203) | array median( function array (line 2211) | array var( function array (line 2221) | array var( function array (line 2248) | array var( function array (line 2257) | array std( function array (line 2267) | array std( function array (line 2276) | array std( function array (line 2285) | array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { function array (line 2291) | array prod( function array (line 2322) | array prod( function array (line 2330) | array max(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { function array (line 2336) | array max( function array (line 2359) | array max( function array (line 2367) | array min(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { function array (line 2373) | array min( function array (line 2399) | array min( function array (line 2407) | array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} *... function array (line 2419) | array argmin( function array (line 2444) | array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} *... function array (line 2456) | array argmax( function array (line 2481) | array bartlett(int M, StreamOrDevice s /* = {} */) { function array (line 2496) | array hanning(int M, StreamOrDevice s /* = {} */) { function array (line 2509) | array hamming(int M, StreamOrDevice s /* = {} */) { function array (line 2530) | array blackman(int M, StreamOrDevice s /* = {} */) { function array (line 2558) | array sort(const array& a, StreamOrDevice s /* = {} */) { function array (line 2564) | array sort(const array& a, int axis, StreamOrDevice s /* = {} */) { function array (line 2579) | array argsort(const array& a, StreamOrDevice s /* = {} */) { function array (line 2585) | array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) { function array (line 2603) | array partition(const array& a, int kth, StreamOrDevice s /* = {} */) { function array (line 2612) | array partition( function array (line 2644) | array argpartition(const array& a, int kth, StreamOrDevice s /* = {} *... function array (line 2653) | array argpartition( function array (line 2682) | array topk(const array& a, int k, StreamOrDevice s /* = {}*/) { function array (line 2688) | array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) { function array (line 2716) | array logsumexp(const array& a, bool keepdims, StreamOrDevice s /* = {... function array (line 2722) | array logsumexp( function array (line 2770) | array logsumexp( function array (line 2778) | array abs(const array& a, StreamOrDevice s /* = {} */) { function array (line 2787) | array negative(const array& a, StreamOrDevice s /* = {} */) { function array (line 2795) | array operator-(const array& a) { function array (line 2799) | array sign(const array& a, StreamOrDevice s /* = {} */) { function array (line 2803) | array logical_not(const array& a, StreamOrDevice s /* = {} */) { function array (line 2811) | array logical_and(const array& a, const array& b, StreamOrDevice s /* ... function array (line 2821) | array operator&&(const array& a, const array& b) { function array (line 2825) | array logical_or(const array& a, const array& b, StreamOrDevice s /* =... function array (line 2835) | array operator||(const array& a, const array& b) { function array (line 2839) | array reciprocal(const array& a, StreamOrDevice s /* = {} */) { function array (line 2844) | array add(const array& a, const array& b, StreamOrDevice s /* = {} */) { function array (line 2853) | array operator+(const array& a, const array& b) { function array (line 2857) | array subtract(const array& a, const array& b, StreamOrDevice s /* = {... function array (line 2869) | array operator-(const array& a, const array& b) { function array (line 2873) | array multiply(const array& a, const array& b, StreamOrDevice s /* = {... function array (line 2885) | array operator*(const array& a, const array& b) { function array (line 2889) | array divide(const array& a, const array& b, StreamOrDevice s /* = {} ... function array (line 2897) | array operator/(const array& a, const array& b) { function array (line 2900) | array operator/(double a, const array& b) { function array (line 2903) | array operator/(const array& a, double b) { function array (line 2907) | array floor_divide( function array (line 2922) | array remainder(const array& a, const array& b, StreamOrDevice s /* = ... function array (line 2933) | array operator%(const array& a, const array& b) { function divmod (line 2937) | std::vector function array (line 2952) | array maximum(const array& a, const array& b, StreamOrDevice s /* = {}... function array (line 2964) | array minimum(const array& a, const array& b, StreamOrDevice s /* = {}... function array (line 2976) | array floor(const array& a, StreamOrDevice s /* = {} */) { function array (line 2984) | array ceil(const array& a, StreamOrDevice s /* = {} */) { function array (line 2991) | array square(const array& a, StreamOrDevice s /* = {} */) { function array (line 2996) | array exp(const array& a, StreamOrDevice s /* = {} */) { function array (line 3002) | array expm1(const array& a, StreamOrDevice s /* = {} */) { function array (line 3009) | array sin(const array& a, StreamOrDevice s /* = {} */) { function array (line 3015) | array cos(const array& a, StreamOrDevice s /* = {} */) { function array (line 3021) | array tan(const array& a, StreamOrDevice s /* = {} */) { function array (line 3027) | array arcsin(const array& a, StreamOrDevice s /* = {} */) { function array (line 3034) | array arccos(const array& a, StreamOrDevice s /* = {} */) { function array (line 3041) | array arctan(const array& a, StreamOrDevice s /* = {} */) { function array (line 3048) | array arctan2(const array& a, const array& b, StreamOrDevice s /* = {}... function array (line 3056) | array sinh(const array& a, StreamOrDevice s /* = {} */) { function array (line 3062) | array cosh(const array& a, StreamOrDevice s /* = {} */) { function array (line 3068) | array tanh(const array& a, StreamOrDevice s /* = {} */) { function array (line 3074) | array arcsinh(const array& a, StreamOrDevice s /* = {} */) { function array (line 3081) | array arccosh(const array& a, StreamOrDevice s /* = {} */) { function array (line 3088) | array arctanh(const array& a, StreamOrDevice s /* = {} */) { function array (line 3095) | array degrees(const array& a, StreamOrDevice s /* = {} */) { function array (line 3100) | array radians(const array& a, StreamOrDevice s /* = {} */) { function array (line 3105) | array log(const array& a, StreamOrDevice s /* = {} */) { function array (line 3115) | array log2(const array& a, StreamOrDevice s /* = {} */) { function array (line 3125) | array log10(const array& a, StreamOrDevice s /* = {} */) { function array (line 3135) | array log1p(const array& a, StreamOrDevice s /* = {} */) { function array (line 3142) | array logaddexp(const array& a, const array& b, StreamOrDevice s /* = ... function array (line 3155) | array sigmoid(const array& a, StreamOrDevice s /* = {} */) { function array (line 3162) | array erf(const array& a, StreamOrDevice s /* = {} */) { function array (line 3171) | array erfinv(const array& a, StreamOrDevice s /* = {} */) { function array (line 3180) | array stop_gradient(const array& a, StreamOrDevice s /* = {} */) { function array (line 3185) | array round(const array& a, int decimals, StreamOrDevice s /* = {} */) { function array (line 3200) | array matmul( function array (line 3277) | array gather( function array (line 3365) | array kron(const array& a, const array& b, StreamOrDevice s /* = {} */) { function array (line 3393) | array take( function array (line 3438) | array take(const array& a, const array& indices, StreamOrDevice s /* =... function array (line 3442) | array take(const array& a, int index, int axis, StreamOrDevice s /* = ... function array (line 3468) | array take(const array& a, int index, StreamOrDevice s /* = {} */) { function array (line 3472) | array take_along_axis( function array (line 3506) | array scatter_axis( function array (line 3559) | array put_along_axis( function array (line 3568) | array scatter_add_axis( function array (line 3578) | array scatter( function array (line 3672) | array scatter( function array (line 3681) | array scatter_add( function array (line 3690) | array scatter_prod( function array (line 3699) | array scatter_max( function array (line 3708) | array scatter_min( function array (line 3717) | array masked_scatter( function array (line 3800) | array sqrt(const array& a, StreamOrDevice s /* = {} */) { function array (line 3809) | array rsqrt(const array& a, StreamOrDevice s /* = {} */) { function array (line 3818) | array softmax( function array (line 3862) | array softmax( function array (line 3871) | array power(const array& a, const array& b, StreamOrDevice s /* = {} *... function array (line 3881) | array cumsum( function array (line 3904) | array cumsum( function array (line 3912) | array cumprod( function array (line 3934) | array cumprod( function array (line 3942) | array cummax( function array (line 3964) | array cummax( function array (line 3972) | array cummin( function array (line 3994) | array cummin( function array (line 4002) | array logcumsumexp( function array (line 4024) | array logcumsumexp( function run_conv_checks (line 4037) | inline void function array (line 4099) | array conv1d( function array (line 4120) | array conv2d( function array (line 4142) | array conv3d( function array (line 4166) | array conv_transpose_general( function array (line 4204) | array conv_transpose1d( function array (line 4218) | array conv_transpose2d( function array (line 4239) | array conv_transpose3d( function array (line 4262) | array conv_general( function quantization_params_from_mode (line 4374) | std::pair quantization_params_from_mode( function validate_mode_with_type (line 4403) | std::pair validate_mode_with_type( function validate_global_scale (line 4455) | void validate_global_scale( function array (line 4483) | array quantized_matmul( function validate_qqmm_inputs (line 4535) | void validate_qqmm_inputs( function extract_qqmm_dims (line 4595) | std::pair extract_qqmm_dims( function array (line 4625) | array qqmm( function array (line 4694) | array pack_and_quantize( function affine_quantize (line 4741) | std::vector function fp_quantize (line 4807) | std::vector fp_quantize( function quantize (line 4914) | std::vector quantize( function array (line 4960) | array affine_dequantize( function array (line 5054) | array fp_dequantize( function array (line 5177) | array dequantize( function array (line 5238) | array from_fp8(array x, Dtype dtype, StreamOrDevice s) { function array (line 5258) | array to_fp8(array x, StreamOrDevice s) { function array (line 5272) | array gather_qmm( function array (line 5369) | array tensordot( function array (line 5392) | array tensordot( function array (line 5454) | array outer(const array& a, const array& b, StreamOrDevice s /* = {} *... function array (line 5459) | array inner(const array& a, const array& b, StreamOrDevice s /* = {} *... function array (line 5472) | array addmm( function array (line 5614) | array block_masked_mm( function array (line 5791) | array gather_mm( function array (line 5895) | array segmented_mm( function array (line 5943) | array diagonal( function array (line 5996) | array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} *... function array (line 6019) | array trace( function array (line 6061) | array trace( function array (line 6070) | array trace(const array& a, StreamOrDevice s /* = {} */) { function depends (line 6075) | std::vector depends( function array (line 6100) | array atleast_1d(const array& a, StreamOrDevice s /* = {} */) { function atleast_1d (line 6107) | std::vector atleast_1d( function array (line 6118) | array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { function atleast_2d (line 6129) | std::vector atleast_2d( function array (line 6140) | array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { function atleast_3d (line 6153) | std::vector atleast_3d( function array (line 6164) | array number_of_elements( function array (line 6196) | array conjugate(const array& a, StreamOrDevice s /* = {} */) { function array (line 6205) | array bitwise_impl( function array (line 6231) | array bitwise_and(const array& a, const array& b, StreamOrDevice s /* ... function array (line 6234) | array operator&(const array& a, const array& b) { function array (line 6238) | array bitwise_or(const array& a, const array& b, StreamOrDevice s /* =... function array (line 6241) | array operator|(const array& a, const array& b) { function array (line 6245) | array bitwise_xor(const array& a, const array& b, StreamOrDevice s /* ... function array (line 6248) | array operator^(const array& a, const array& b) { function array (line 6252) | array left_shift(const array& a, const array& b, StreamOrDevice s /* =... function array (line 6259) | array operator<<(const array& a, const array& b) { function array (line 6263) | array right_shift(const array& a, const array& b, StreamOrDevice s /* ... function array (line 6276) | array operator>>(const array& a, const array& b) { function array (line 6280) | array bitwise_invert(const array& a, StreamOrDevice s /* = {} */) { function array (line 6291) | array operator~(const array& a) { function array (line 6295) | array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {... function array (line 6323) | array roll( function array (line 6367) | array roll(const array& a, int shift, StreamOrDevice s /* = {} */) { function array (line 6375) | array roll(const array& a, const Shape& shift, StreamOrDevice s /* = {... function array (line 6383) | array roll(const array& a, int shift, int axis, StreamOrDevice s /* = ... function array (line 6387) | array roll( function array (line 6396) | array roll( function array (line 6408) | array real(const array& a, StreamOrDevice s /* = {} */) { function array (line 6415) | array imag(const array& a, StreamOrDevice s /* = {} */) { function array (line 6422) | array contiguous( FILE: mlx/ops.h function namespace (line 13) | namespace mlx::core { FILE: mlx/primitives.cpp type mlx::core (line 20) | namespace mlx::core { function vmap_binary_op (line 24) | std::tuple vmap_binary_op( function vmap_ternary_op (line 60) | std::tuple vmap_ternary_op( function array (line 121) | array gather_mm_grad( function broadcast_vjp (line 805) | std::vector function Shape (line 859) | Shape Broadcast::output_shape(const std::vector& inputs) { function Shape (line 908) | Shape BroadcastAxes::output_shape( function array (line 1175) | array conv_weight_backward_patches( function conv_out_axis_size (line 1248) | inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int ... function dilate_size (line 1253) | inline int dilate_size(int dim, int dil) { function Shape (line 1259) | Shape Convolution::conv_out_shape( function Shape (line 2021) | Shape ExpandDims::output_shape( function Shape (line 2080) | Shape Flatten::output_shape(const array& input, int start_axis, int en... function Shape (line 2134) | Shape Unflatten::output_shape( function quantization_mode_to_string (line 3330) | std::string quantization_mode_to_string(QuantizationMode mode) { function QuantizationMode (line 3344) | QuantizationMode string_to_quantization_mode( function Shape (line 3810) | Shape Reshape::output_shape(const array& input, Shape shape) { function Shape (line 5387) | Shape Squeeze::output_shape(const array& input, const std::vector... FILE: mlx/primitives.h function class (line 49) | class MLX_API Primitive { function eval_cpu (line 137) | inline void eval_cpu( function eval_gpu (line 142) | inline void eval_gpu( type class (line 155) | enum class function DEFINE_VMAP (line 167) | void eval_gpu(const std::vector& inputs, array& out) override; function DEFINE_VMAP (line 308) | void eval_gpu(const std::vector& inputs, array& out) override; function class (line 380) | class ArgSort : public UnaryPrimitive { type Op (line 450) | enum Op { And, Or, Xor, LeftShift, RightShift } function explicit (line 452) | explicit BitwiseBinary(Stream stream, Op op) function DEFINE_INPUT_OUTPUT_SHAPE (line 477) | bool is_equivalent(const Primitive& other) const override; function DEFINE_NAME (line 559) | void eval_gpu(const std::vector& inputs, array& out) override; function DEFINE_VMAP (line 618) | void eval_gpu(const std::vector& inputs, array& out) override; function class (line 671) | class Concatenate : public UnaryPrimitive { function class (line 692) | class Conjugate : public UnaryPrimitive { function DEFINE_VMAP (line 823) | void eval_gpu(const std::vector& inputs, array& out) override; function class (line 884) | class Depends : public Primitive { function DEFINE_VMAP (line 910) | void eval_gpu(const std::vector& inputs, array& out) override; function class (line 937) | class Select : public UnaryPrimitive { function DEFINE_VMAP (line 1026) | void eval_gpu(const std::vector& inputs, array& out) override; function class (line 1100) | class Flatten : public UnaryPrimitive { function DEFINE_VMAP (line 1130) | void eval_gpu(const std::vector& inputs, array& out) override; function class (line 1177) | class GatherAxis : public UnaryPrimitive { function DEFINE_VMAP (line 1281) | void eval_gpu(const std::vector& inputs, array& out) override; function class (line 1313) | class Log : public UnaryPrimitive { function DEFINE_VMAP (line 1408) | void eval_gpu(const std::vector& inputs, array& out) override; function class (line 1957) | class MaskedScatter : public UnaryPrimitive { function DEFINE_VMAP (line 1976) | void eval_gpu(const std::vector& inputs, array& out) override; function state (line 2159) | void eval_gpu(const std::vector& inputs, array& out) override; function class (line 2221) | class Square : public UnaryPrimitive { function DEFINE_VMAP (line 2284) | void eval_gpu(const std::vector& inputs, array& out) override; function DEFINE_VMAP (line 2337) | void eval_gpu(const std::vector& inputs, array& out) override; function DEFINE_VMAP (line 2377) | void eval_gpu(const std::vector& inputs, array& out) override; function class (line 2415) | class QRF : public Primitive { FILE: mlx/random.cpp type mlx::core::random (line 12) | namespace mlx::core::random { function array (line 20) | array KeySequence::next() { function seed (line 26) | void seed(uint64_t seed) { function array (line 30) | array key(uint64_t seed) { function array (line 36) | array bits( function split (line 75) | std::pair split(const array& key, StreamOrDevice s /* = ... function array (line 81) | array split(const array& key, int num, StreamOrDevice s /* = {} */) { function T (line 88) | T below_one() { function array (line 95) | array uniform( function array (line 143) | array uniform( function array (line 152) | inline array complex_normal( function array (line 174) | array normal( function array (line 206) | array multivariate_normal( function array (line 271) | array randint( function array (line 286) | array bernoulli( function array (line 311) | array bernoulli( function array (line 318) | array bernoulli( function array (line 324) | array truncated_normal( function array (line 351) | array truncated_normal( function array (line 361) | array gumbel( function get_valid_axis (line 371) | int get_valid_axis(int axis, int ndim) { function array (line 382) | array categorical_impl( function array (line 395) | array categorical( function array (line 418) | array categorical( function array (line 432) | array categorical( function array (line 443) | array laplace( function array (line 477) | array permutation( function array (line 485) | array permutation( FILE: mlx/random.h function namespace (line 13) | namespace mlx::core::random { FILE: mlx/scheduler.cpp type mlx::core (line 7) | namespace mlx::core { function Stream (line 9) | Stream default_stream(Device d) { function set_default_stream (line 17) | void set_default_stream(Stream s) { function Stream (line 25) | Stream get_stream(int index) { function get_streams (line 29) | std::vector get_streams() { function Stream (line 33) | Stream new_stream(Device d) { function Stream (line 41) | Stream new_stream() { function synchronize (line 45) | void synchronize(Stream s) { function synchronize (line 56) | void synchronize() { type scheduler (line 60) | namespace scheduler { function Scheduler (line 63) | Scheduler& scheduler() { FILE: mlx/scheduler.h function namespace (line 16) | namespace mlx::core::scheduler { function thread_fn (line 36) | void thread_fn() { function class (line 67) | class Scheduler { FILE: mlx/small_vector.h function namespace (line 37) | namespace mlx::core { function allocator_ (line 155) | allocator_(allocator) { function T (line 233) | T* data() { function T (line 236) | const T* data() const { function iterator (line 240) | iterator begin() { function iterator (line 247) | iterator end() { function const (line 262) | auto rbegin() { function const (line 269) | auto rend() { function T (line 290) | const T& front() const { function T (line 299) | const T& back() const { function T (line 310) | const T& at(size_t index) const { function T (line 318) | const T& operator[](size_t index) const { function push_back (line 332) | void push_back(T x) { function iterator (line 342) | iterator insert(iterator pos, T value) { function iterator (line 346) | iterator insert(iterator pos, size_t count, T value) { function reserve (line 431) | void reserve(size_t new_capacity) { function clear (line 438) | void clear() { function MLX_NOINLINE (line 473) | MLX_NOINLINE void free_storage() { function reset_to_inline_storage (line 482) | void reset_to_inline_storage() { function T (line 496) | T* inline_storage_begin() { function T (line 499) | const T* inline_storage_begin() const { type is_vector (line 526) | struct is_vector FILE: mlx/stream.h function namespace (line 10) | namespace mlx::core { FILE: mlx/threadpool.h function class (line 35) | class ThreadPool { function ThreadPool (line 55) | inline ThreadPool::ThreadPool(size_t threads) : stop(false) { function task (line 64) | auto task = std::make_shared>( function resize (line 82) | inline void ThreadPool::resize(size_t threads) { function ThreadPool (line 93) | inline ThreadPool::~ThreadPool() { function stop_and_wait (line 97) | inline void ThreadPool::stop_and_wait() { FILE: mlx/transforms.cpp type mlx::core (line 23) | namespace mlx::core { class Synchronizer (line 29) | class Synchronizer : public Primitive { method Synchronizer (line 31) | explicit Synchronizer(Stream stream) : Primitive(stream) {} method eval_cpu (line 33) | void eval_cpu(const std::vector&, std::vector&) overri... method eval_gpu (line 34) | void eval_gpu(const std::vector&, std::vector&) overri... function array (line 52) | array eval_impl(std::vector outputs, bool async) { function async_eval (line 296) | void async_eval(std::vector outputs) { function eval (line 310) | void eval(std::vector outputs) { function vjp (line 327) | std::pair, std::vector> vjp( function vjp (line 506) | std::pair, std::vector> vjp( function vjp (line 515) | std::pair vjp( function jvp (line 526) | std::pair, std::vector> jvp( function jvp (line 641) | std::pair jvp( function ValueAndGradFn (line 652) | ValueAndGradFn value_and_grad( type detail (line 692) | namespace detail { function vmap_trace (line 694) | std::pair, std::vector> vmap_trace( function vmap_replace (line 754) | std::vector vmap_replace( function vmap (line 886) | std::function(const std::vector&)> vmap( function vmap (line 919) | std::function vmap( function vmap (line 933) | std::function vmap( function custom_function (line 946) | std::function(const std::vector&)> custom_fu... function custom_vjp (line 1043) | std::function(const std::vector&)> custom_vjp( function checkpoint (line 1052) | std::function(const std::vector&)> checkpoint( FILE: mlx/transforms.h function SimpleValueAndGradFn (line 102) | SimpleValueAndGradFn inline value_and_grad( FILE: mlx/transforms_impl.h function namespace (line 7) | namespace mlx::core::detail { function retain_graph (line 51) | struct RetainGraph { function in_tracing (line 69) | inline bool in_tracing() { function in_dynamic_tracing (line 75) | inline bool in_dynamic_tracing() { function in_grad_tracing (line 80) | inline bool in_grad_tracing() { function retain_graph (line 84) | inline bool retain_graph() { FILE: mlx/types/bf16.h function namespace (line 13) | namespace mlx::core { FILE: mlx/types/complex.h function namespace (line 7) | namespace mlx::core { FILE: mlx/types/fp16.h function namespace (line 13) | namespace mlx::core { FILE: mlx/types/half_types.h function namespace (line 8) | namespace mlx::core { function namespace (line 16) | namespace mlx::core { function namespace (line 25) | namespace mlx::core { function namespace (line 33) | namespace mlx::core { FILE: mlx/types/limits.h type numeric_limits (line 13) | struct numeric_limits type numeric_limits (line 16) | struct numeric_limits type numeric_limits (line 19) | struct numeric_limits function float16_t (line 25) | constexpr static float16_t bits_to_half(uint16_t v) { function bfloat16_t (line 45) | struct numeric_limits { FILE: mlx/utils.cpp type mlx::core (line 12) | namespace mlx::core { function Stream (line 14) | Stream to_stream(StreamOrDevice s) { function Stream (line 24) | Stream to_stream(StreamOrDevice s, Device default_) { function PrintFormatter (line 80) | PrintFormatter& get_global_formatter() { function abort_with_exception (line 85) | void abort_with_exception(const std::exception& error) { function Dtype (line 92) | Dtype result_type(const std::vector& arrays) { function Shape (line 100) | Shape broadcast_shapes(const Shape& s1, const Shape& s2) { function normalize_axis_index (line 133) | int normalize_axis_index( function print_subarray (line 180) | void print_subarray(std::ostream& os, const array& a, size_t index, in... function print_array (line 206) | void print_array(std::ostream& os, const array& a) { type env (line 251) | namespace env { function get_var (line 253) | int get_var(const char* name, int default_value) { function get_var (line 261) | std::string get_var(const char* name, const char* default_value) { function set_finfo_limits (line 272) | void set_finfo_limits(double& min, double& max, double& eps) { function set_iinfo_limits (line 299) | void set_iinfo_limits(int64_t& min, uint64_t& max) { FILE: mlx/utils.h function namespace (line 14) | namespace mlx::core { type PrintFormatter (line 41) | struct PrintFormatter { function finfo (line 64) | struct MLX_API finfo { function iinfo (line 73) | struct MLX_API iinfo { function Dtype (line 81) | inline Dtype result_type(const array& a, const array& b) { function Dtype (line 84) | inline Dtype result_type(const array& a, const array& b, const array& c) { function is_power_of_2 (line 125) | inline bool is_power_of_2(int n) { function next_power_of_2 (line 129) | inline int next_power_of_2(int n) { function namespace (line 136) | namespace env { FILE: mlx/version.cpp type mlx::core (line 5) | namespace mlx::core { FILE: mlx/version.h function namespace (line 13) | namespace mlx::core { FILE: python/mlx/__main__.py function main (line 4) | def main() -> None: FILE: python/mlx/_distributed_utils/common.py class Host (line 13) | class Host: class Hostfile (line 21) | class Hostfile: method to_json (line 26) | def to_json(self): method from_file (line 37) | def from_file(cls, hostfile): method from_list (line 91) | def from_list(cls, hostlist, repeats=1): class OptionalBoolAction (line 106) | class OptionalBoolAction(argparse.Action): method __call__ (line 107) | def __call__(self, parser, namespace, values, option_string=None): function positive_number (line 114) | def positive_number(x): function log (line 121) | def log(verbose, *args, **kwargs): function log_warning (line 128) | def log_warning(*args, **kwargs): function log_error (line 133) | def log_error(*args, **kwargs): FILE: python/mlx/_distributed_utils/config.py class SSHInfo (line 26) | class SSHInfo: method __bool__ (line 30) | def __bool__(self): class ThunderboltPort (line 35) | class ThunderboltPort: class ThunderboltHost (line 42) | class ThunderboltHost: function add_ips (line 47) | def add_ips(hosts, verbose=False): function save_hostfile (line 72) | def save_hostfile(args, hostfile): function check_rdma (line 82) | def check_rdma(hosts, verbose=False, strict=True): function can_auto_setup (line 111) | def can_auto_setup(hosts, sshinfo, auto_setup=False): class IPConfigurator (line 123) | class IPConfigurator: method __init__ (line 124) | def __init__(self, hosts, tb_hosts, uuid_reverse_index): method setup (line 162) | def setup(self, verbose=False, auto_setup=False): function parse_hardware_ports (line 188) | def parse_hardware_ports(ports_string): function extract_connectivity (line 200) | def extract_connectivity(hosts, verbose): function make_connectivity_matrix (line 265) | def make_connectivity_matrix(tb_hosts, uuid_reverse_index): function tb_connectivity_to_dot (line 277) | def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index): function extract_rings (line 308) | def extract_rings(connectivity): function check_valid_mesh (line 337) | def check_valid_mesh(hosts, connectivity, strict=True): function check_valid_ring (line 356) | def check_valid_ring(hosts, rings, strict=True): function check_ssh_connections (line 370) | def check_ssh_connections(hosts, ignore_unreachable=False): function prepare_ethernet_hostfile (line 434) | def prepare_ethernet_hostfile(args, hosts): function configure_ring (line 445) | def configure_ring(args, hosts, ips, ring, sshinfo): function configure_jaccl (line 465) | def configure_jaccl(args, hosts, ips, sshinfo): function configure_jaccl_ring (line 486) | def configure_jaccl_ring(args, hosts, ips, ring, sshinfo): function prepare_tb_hostfile (line 515) | def prepare_tb_hostfile(args, hosts, sshinfo): function main (line 568) | def main(): FILE: python/mlx/_distributed_utils/launch.py class CommandProcess (line 25) | class CommandProcess: method process (line 27) | def process(self): method exit_status (line 32) | def exit_status(self): method preprocess_output (line 37) | def preprocess_output(self, data: str, is_stdout=False): method terminate (line 42) | def terminate(self): class RemoteProcess (line 47) | class RemoteProcess(CommandProcess): method __init__ (line 48) | def __init__(self, rank, host, python, cwd, files, env, command): method process (line 69) | def process(self): method exit_status (line 73) | def exit_status(self): method preprocess_output (line 76) | def preprocess_output(self, data, is_stdout=False): method terminate (line 84) | def terminate(self): method make_launch_script (line 107) | def make_launch_script(rank, cwd, files, env, command, is_local): method make_kill_script (line 156) | def make_kill_script(pidfile): function _launch_with_io (line 170) | def _launch_with_io(command_class, arguments, verbose): function launch_ring (line 292) | def launch_ring(parser, hosts, args, command): function launch_nccl (line 327) | def launch_nccl(parser, hosts, args, command): function launch_jaccl (line 366) | def launch_jaccl(parser, hosts, args, command): function get_mpi_libname (line 396) | def get_mpi_libname(): function launch_mpi (line 418) | def launch_mpi(parser, hosts, args, command): function main (line 457) | def main(): FILE: python/mlx/_reprlib_fix.py function repr_array (line 9) | def repr_array(self, x, maxlevel): FILE: python/mlx/extension.py class CMakeExtension (line 18) | class CMakeExtension(Extension): method __init__ (line 19) | def __init__(self, name: str, sourcedir: str = "") -> None: class CMakeBuild (line 24) | class CMakeBuild(build_ext): method build_extension (line 25) | def build_extension(self, ext: CMakeExtension) -> None: method run (line 72) | def run(self) -> None: FILE: python/mlx/nn/init.py function constant (line 9) | def constant( function normal (line 37) | def normal( function uniform (line 68) | def uniform( function identity (line 99) | def identity(dtype: mx.Dtype = mx.float32) -> Callable[[mx.array], mx.ar... function _calculate_fan_in_fan_out (line 128) | def _calculate_fan_in_fan_out(x): function glorot_normal (line 149) | def glorot_normal( function glorot_uniform (line 192) | def glorot_uniform( function he_normal (line 235) | def he_normal( function he_uniform (line 293) | def he_uniform( function sparse (line 353) | def sparse( function orthogonal (line 400) | def orthogonal( FILE: python/mlx/nn/layers/activations.py function _make_activation_module (line 11) | def _make_activation_module(f): function sigmoid (line 20) | def sigmoid(x): function relu (line 30) | def relu(x): function relu2 (line 39) | def relu2(x): function relu6 (line 48) | def relu6(x): function leaky_relu (line 57) | def leaky_relu(x, negative_slope=0.01): function log_softmax (line 66) | def log_softmax(x, axis=-1): function elu (line 75) | def elu(x, alpha=1.0): function softmax (line 84) | def softmax(x, axis=-1): function softplus (line 93) | def softplus(x): function softsign (line 102) | def softsign(x): function softshrink (line 111) | def softshrink(x, lambd: float = 0.5): function celu (line 125) | def celu(x, alpha=1.0): function silu (line 135) | def silu(x): function log_sigmoid (line 145) | def log_sigmoid(x): function gelu (line 154) | def gelu(x) -> mx.array: function gelu_approx (line 169) | def gelu_approx(x): function gelu_fast_approx (line 186) | def gelu_fast_approx(x): function glu (line 207) | def glu(x: mx.array, axis: int = -1) -> mx.array: function step (line 224) | def step(x: mx.array, threshold: float = 0.0): function selu (line 244) | def selu(x): function prelu (line 261) | def prelu(x: mx.array, alpha: mx.array) -> mx.array: function mish (line 273) | def mish(x: mx.array) -> mx.array: function hardswish (line 288) | def hardswish(x): function hard_tanh (line 299) | def hard_tanh(x, min_val=-1.0, max_val=1.0): function hard_shrink (line 308) | def hard_shrink(x, lambd=0.5): function softmin (line 322) | def softmin(x, axis=-1): function tanh (line 330) | def tanh(x): class GLU (line 338) | class GLU(Module): method __init__ (line 351) | def __init__(self, axis: int = -1): method __call__ (line 355) | def __call__(self, x) -> Any: class Sigmoid (line 360) | class Sigmoid(Module): class Mish (line 369) | class Mish(Module): class ReLU (line 381) | class ReLU(Module): class ReLU2 (line 390) | class ReLU2(Module): class ReLU6 (line 398) | class ReLU6(Module): class LeakyReLU (line 405) | class LeakyReLU(Module): method __init__ (line 414) | def __init__(self, negative_slope=1e-2): method __call__ (line 418) | def __call__(self, x): class ELU (line 422) | class ELU(Module): method __init__ (line 432) | def __init__(self, alpha=1.0): method __call__ (line 436) | def __call__(self, x): class Softmax (line 441) | class Softmax(Module): class Softplus (line 449) | class Softplus(Module): class Softsign (line 457) | class Softsign(Module): class Softshrink (line 464) | class Softshrink(Module): method __init__ (line 473) | def __init__(self, lambd=0.5): method __call__ (line 477) | def __call__(self, x): class CELU (line 481) | class CELU(Module): method __init__ (line 492) | def __init__(self, alpha=1.0): method __call__ (line 496) | def __call__(self, x): class SiLU (line 501) | class SiLU(Module): class LogSoftmax (line 509) | class LogSoftmax(Module): class LogSigmoid (line 517) | class LogSigmoid(Module): class PReLU (line 524) | class PReLU(Module): method __init__ (line 536) | def __init__(self, num_parameters=1, init=0.25): method __call__ (line 540) | def __call__(self, x: mx.array): class GELU (line 544) | class GELU(Module): method __init__ (line 572) | def __init__(self, approx="none"): method __call__ (line 581) | def __call__(self, x): class Tanh (line 590) | class Tanh(Module): class Hardswish (line 598) | class Hardswish(Module): class Step (line 605) | class Step(Module): method __init__ (line 621) | def __init__(self, threshold: float = 0.0): method __call__ (line 625) | def __call__(self, x: mx.array): class SELU (line 630) | class SELU(Module): class HardTanh (line 638) | class HardTanh(Module): class HardShrink (line 646) | class HardShrink(Module): class Softmin (line 657) | class Softmin(Module): FILE: python/mlx/nn/layers/base.py class Module (line 12) | class Module(dict): method __init__ (line 61) | def __init__(self): method training (line 67) | def training(self): method state (line 72) | def state(self): method _extra_repr (line 84) | def _extra_repr(self) -> str: method __repr__ (line 87) | def __repr__(self): method __getattr__ (line 99) | def __getattr__(self, key: str): method __setattr__ (line 105) | def __setattr__(self, key: str, val: Any): method __delattr__ (line 117) | def __delattr__(self, name): method load_weights (line 123) | def load_weights( method save_weights (line 209) | def save_weights(self, file: str): method is_module (line 227) | def is_module(value): method valid_child_filter (line 231) | def valid_child_filter(module, key, value): method valid_parameter_filter (line 235) | def valid_parameter_filter(module, key, value): method trainable_parameter_filter (line 239) | def trainable_parameter_filter(module, key, value): method filter_and_map (line 245) | def filter_and_map( method parameters (line 280) | def parameters(self): method trainable_parameters (line 285) | def trainable_parameters(self): method children (line 290) | def children(self): method leaf_modules (line 296) | def leaf_modules(self): method update (line 304) | def update(self, parameters: dict, strict: bool = True) -> Module: method apply (line 366) | def apply( method update_modules (line 389) | def update_modules(self, modules: dict, strict: bool = True) -> Module: method apply_to_modules (line 412) | def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> ... method modules (line 435) | def modules(self): method named_modules (line 445) | def named_modules(self): method _validate_keys (line 456) | def _validate_keys(self, keys, strict): method freeze (line 464) | def freeze( method unfreeze (line 519) | def unfreeze( method _set_training_mode (line 569) | def _set_training_mode(self, mode: bool) -> None: method train (line 572) | def train(self, mode: bool = True) -> Module: method eval (line 590) | def eval(self) -> Module: method set_dtype (line 597) | def set_dtype( function _update_modules (line 619) | def _update_modules(dst, modules, strict): function _unwrap (line 649) | def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn): FILE: python/mlx/nn/layers/containers.py class Sequential (line 6) | class Sequential(Module): method __init__ (line 17) | def __init__(self, *modules): method __call__ (line 21) | def __call__(self, x): FILE: python/mlx/nn/layers/convolution.py class Conv1d (line 10) | class Conv1d(Module): method __init__ (line 34) | def __init__( method _extra_repr (line 67) | def _extra_repr(self): method __call__ (line 76) | def __call__(self, x): class Conv2d (line 85) | class Conv2d(Module): method __init__ (line 110) | def __init__( method _extra_repr (line 147) | def _extra_repr(self): method __call__ (line 156) | def __call__(self, x): class Conv3d (line 165) | class Conv3d(Module): method __init__ (line 189) | def __init__( method _extra_repr (line 220) | def _extra_repr(self): method __call__ (line 228) | def __call__(self, x): FILE: python/mlx/nn/layers/convolution_transpose.py class ConvTranspose1d (line 10) | class ConvTranspose1d(Module): method __init__ (line 34) | def __init__( method _extra_repr (line 61) | def _extra_repr(self): method __call__ (line 70) | def __call__(self, x): class ConvTranspose2d (line 84) | class ConvTranspose2d(Module): method __init__ (line 109) | def __init__( method _extra_repr (line 140) | def _extra_repr(self): method __call__ (line 149) | def __call__(self, x): class ConvTranspose3d (line 163) | class ConvTranspose3d(Module): method __init__ (line 189) | def __init__( method _extra_repr (line 222) | def _extra_repr(self): method __call__ (line 231) | def __call__(self, x): FILE: python/mlx/nn/layers/distributed.py function sum_gradients (line 15) | def sum_gradients(group): function _split (line 30) | def _split(weight, segments, axis): function _shard (line 40) | def _shard( function _all_to_sharded (line 84) | def _all_to_sharded(segments): function _sharded_to_all (line 96) | def _sharded_to_all(segments): function _check_sharding (line 108) | def _check_sharding(sharding): function shard_inplace (line 118) | def shard_inplace( function shard_linear (line 158) | def shard_linear( class AllToShardedLinear (line 193) | class AllToShardedLinear(Module): method __init__ (line 209) | def __init__( method _extra_repr (line 240) | def _extra_repr(self) -> str: method __call__ (line 246) | def __call__(self, x: mx.array) -> mx.array: method from_linear (line 258) | def from_linear( class ShardedToAllLinear (line 274) | class ShardedToAllLinear(Module): method __init__ (line 293) | def __init__( method _extra_repr (line 324) | def _extra_repr(self) -> str: method __call__ (line 330) | def __call__(self, x: mx.array) -> mx.array: method from_linear (line 341) | def from_linear( class QuantizedAllToShardedLinear (line 357) | class QuantizedAllToShardedLinear(Module): method __init__ (line 381) | def __init__( method unfreeze (line 425) | def unfreeze(self, *args, **kwargs): method _extra_repr (line 431) | def _extra_repr(self) -> str: method __call__ (line 440) | def __call__(self, x: mx.array) -> mx.array: method from_quantized_linear (line 459) | def from_quantized_linear( class QuantizedShardedToAllLinear (line 490) | class QuantizedShardedToAllLinear(Module): method __init__ (line 516) | def __init__( method unfreeze (line 560) | def unfreeze(self, *args, **kwargs): method _extra_repr (line 566) | def _extra_repr(self) -> str: method __call__ (line 574) | def __call__(self, x: mx.array) -> mx.array: method from_quantized_linear (line 591) | def from_quantized_linear( FILE: python/mlx/nn/layers/dropout.py class Dropout (line 7) | class Dropout(Module): method __init__ (line 18) | def __init__(self, p: float = 0.5): method _extra_repr (line 26) | def _extra_repr(self) -> str: method __call__ (line 29) | def __call__(self, x: mx.array) -> mx.array: class Dropout2d (line 38) | class Dropout2d(Module): method __init__ (line 61) | def __init__(self, p: float = 0.5): method _extra_repr (line 69) | def _extra_repr(self) -> str: method __call__ (line 72) | def __call__(self, x: mx.array) -> mx.array: class Dropout3d (line 91) | class Dropout3d(Module): method __init__ (line 110) | def __init__(self, p: float = 0.5): method _extra_repr (line 118) | def _extra_repr(self) -> str: method __call__ (line 121) | def __call__(self, x: mx.array) -> mx.array: FILE: python/mlx/nn/layers/embedding.py class Embedding (line 11) | class Embedding(Module): method __init__ (line 23) | def __init__(self, num_embeddings: int, dims: int): method _extra_repr (line 28) | def _extra_repr(self): method __call__ (line 31) | def __call__(self, x): method as_linear (line 34) | def as_linear(self, x): method to_quantized (line 43) | def to_quantized( FILE: python/mlx/nn/layers/linear.py class Identity (line 11) | class Identity(Module): method __init__ (line 19) | def __init__(self, *args: Any, **kwargs: Any) -> None: method __call__ (line 22) | def __call__(self, x: mx.array) -> mx.array: class Linear (line 26) | class Linear(Module): method __init__ (line 48) | def __init__(self, input_dims: int, output_dims: int, bias: bool = Tru... method _extra_repr (line 63) | def _extra_repr(self) -> str: method __call__ (line 66) | def __call__(self, x: mx.array) -> mx.array: method to_quantized (line 73) | def to_quantized( class Bilinear (line 111) | class Bilinear(Module): method __init__ (line 135) | def __init__( method _extra_repr (line 152) | def _extra_repr(self) -> str: method __call__ (line 159) | def __call__(self, x1: mx.array, x2: mx.array) -> mx.array: FILE: python/mlx/nn/layers/normalization.py class InstanceNorm (line 9) | class InstanceNorm(Module): method __init__ (line 42) | def __init__( method _extra_repr (line 55) | def _extra_repr(self): method __call__ (line 58) | def __call__(self, x: mx.array) -> mx.array: class LayerNorm (line 69) | class LayerNorm(Module): method __init__ (line 93) | def __init__( method _extra_repr (line 104) | def _extra_repr(self): method __call__ (line 107) | def __call__(self, x): class RMSNorm (line 113) | class RMSNorm(Module): method __init__ (line 134) | def __init__(self, dims: int, eps: float = 1e-5): method _extra_repr (line 139) | def _extra_repr(self): method __call__ (line 142) | def __call__(self, x): class GroupNorm (line 146) | class GroupNorm(Module): method __init__ (line 176) | def __init__( method _extra_repr (line 193) | def _extra_repr(self): method _pytorch_compatible_group_norm (line 199) | def _pytorch_compatible_group_norm(self, x): method _group_norm (line 215) | def _group_norm(self, x): method __call__ (line 230) | def __call__(self, x): class BatchNorm (line 240) | class BatchNorm(Module): method __init__ (line 281) | def __init__( method unfreeze (line 305) | def unfreeze(self, *args, **kwargs): method _extra_repr (line 311) | def _extra_repr(self): method _calc_stats (line 318) | def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]: method __call__ (line 336) | def __call__(self, x: mx.array) -> mx.array: FILE: python/mlx/nn/layers/pooling.py function _value_or_list (line 11) | def _value_or_list(x, n, msg): function _non_overlapping_sliding_windows (line 23) | def _non_overlapping_sliding_windows(x, shape, window_shape): function _sliding_windows (line 39) | def _sliding_windows(x, window_shape, window_strides): class _Pool (line 84) | class _Pool(Module): method __init__ (line 85) | def __init__(self, pooling_function, kernel_size, stride, padding, pad... method _extra_repr (line 95) | def _extra_repr(self): method __call__ (line 102) | def __call__(self, x): class _Pool1d (line 113) | class _Pool1d(_Pool): method __init__ (line 114) | def __init__( class _Pool2d (line 137) | class _Pool2d(_Pool): method __init__ (line 138) | def __init__( class _Pool3d (line 161) | class _Pool3d(_Pool): method __init__ (line 162) | def __init__( class MaxPool1d (line 185) | class MaxPool1d(_Pool1d): method __init__ (line 207) | def __init__( class AvgPool1d (line 216) | class AvgPool1d(_Pool1d): method __init__ (line 238) | def __init__( class MaxPool2d (line 247) | class MaxPool2d(_Pool2d): method __init__ (line 276) | def __init__( class AvgPool2d (line 285) | class AvgPool2d(_Pool2d): method __init__ (line 314) | def __init__( class MaxPool3d (line 323) | class MaxPool3d(_Pool3d): method __init__ (line 353) | def __init__( class AvgPool3d (line 362) | class AvgPool3d(_Pool3d): method __init__ (line 392) | def __init__( FILE: python/mlx/nn/layers/positional_encoding.py class RoPE (line 10) | class RoPE(Module): method __init__ (line 30) | def __init__( method _extra_repr (line 43) | def _extra_repr(self): method __call__ (line 46) | def __call__(self, x, offset: int = 0): class SinusoidalPositionalEncoding (line 57) | class SinusoidalPositionalEncoding(Module): method __init__ (line 77) | def __init__( method __call__ (line 101) | def __call__(self, x): class ALiBi (line 117) | class ALiBi(Module): method create_alibi_matrix (line 119) | def create_alibi_matrix( method create_alibi_slope (line 136) | def create_alibi_slope(num_heads, dtype): method __call__ (line 152) | def __call__(self, attention_scores, offset=0, mask=None): FILE: python/mlx/nn/layers/quantized.py function _defaults_for_mode (line 11) | def _defaults_for_mode(mode, group_size, bits): function quantize (line 22) | def quantize( class QuantizedEmbedding (line 98) | class QuantizedEmbedding(Module): method __init__ (line 117) | def __init__( method __call__ (line 144) | def __call__(self, x): method as_linear (line 155) | def as_linear(self, x): method _extra_repr (line 173) | def _extra_repr(self): method from_embedding (line 180) | def from_embedding( class QuantizedLinear (line 200) | class QuantizedLinear(Module): method __init__ (line 223) | def __init__( method _extra_repr (line 257) | def _extra_repr(self): method __call__ (line 265) | def __call__(self, x): method from_linear (line 281) | def from_linear( class QQLinear (line 305) | class QQLinear(Module): method __init__ (line 338) | def __init__( method _extra_repr (line 360) | def _extra_repr(self): method quantize (line 369) | def quantize(self): method dequantize (line 379) | def dequantize(self): method _set_training_mode (line 391) | def _set_training_mode(self, mode: bool): method __call__ (line 399) | def __call__(self, x): method from_linear (line 411) | def from_linear( FILE: python/mlx/nn/layers/recurrent.py class RNN (line 11) | class RNN(Module): method __init__ (line 39) | def __init__( method _extra_repr (line 68) | def _extra_repr(self): method __call__ (line 75) | def __call__(self, x, hidden=None): class GRU (line 93) | class GRU(Module): method __init__ (line 123) | def __init__( method _extra_repr (line 150) | def _extra_repr(self): method __call__ (line 156) | def __call__(self, x, hidden=None): class LSTM (line 201) | class LSTM(Module): method __init__ (line 234) | def __init__( method _extra_repr (line 256) | def _extra_repr(self): method __call__ (line 262) | def __call__(self, x, hidden=None, cell=None): FILE: python/mlx/nn/layers/transformer.py class MultiHeadAttention (line 15) | class MultiHeadAttention(Module): method __init__ (line 48) | def __init__( method __call__ (line 79) | def __call__(self, queries, keys, values, mask=None): method create_additive_causal_mask (line 96) | def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32): class TransformerEncoderLayer (line 103) | class TransformerEncoderLayer(Module): method __init__ (line 104) | def __init__( method __call__ (line 125) | def __call__(self, x, mask): class TransformerEncoder (line 153) | class TransformerEncoder(Module): method __init__ (line 154) | def __init__( method __call__ (line 175) | def __call__(self, x, mask): class TransformerDecoderLayer (line 182) | class TransformerDecoderLayer(Module): method __init__ (line 183) | def __init__( method __call__ (line 207) | def __call__(self, x, memory, x_mask, memory_mask): class TransformerDecoder (line 244) | class TransformerDecoder(Module): method __init__ (line 245) | def __init__( method __call__ (line 266) | def __call__(self, x, memory, x_mask, memory_mask): class Transformer (line 273) | class Transformer(Module): method __init__ (line 314) | def __init__( method __call__ (line 352) | def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask): FILE: python/mlx/nn/layers/upsample.py function _scaled_indices (line 12) | def _scaled_indices(N, scale, align_corners, dim, ndims): function _nearest_indices (line 27) | def _nearest_indices(N, scale, dim, ndims): function _linear_indices (line 40) | def _linear_indices(N, scale, align_corners, dim, ndims): function _cubic_indices (line 54) | def _cubic_indices(N, scale, align_corners, dim, ndims): function upsample_nearest (line 92) | def upsample_nearest(x: mx.array, scale_factor: Tuple): function _interpolate (line 122) | def _interpolate( function upsample_linear (line 148) | def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: boo... function upsample_cubic (line 157) | def upsample_cubic(x: mx.array, scale_factor: Tuple, align_corners: bool... class Upsample (line 166) | class Upsample(Module): method __init__ (line 228) | def __init__( method _extra_repr (line 244) | def _extra_repr(self) -> str: method __call__ (line 250) | def __call__(self, x: mx.array) -> mx.array: FILE: python/mlx/nn/losses.py function _reduce (line 11) | def _reduce(loss: mx.array, reduction: Reduction = "none"): function cross_entropy (line 23) | def cross_entropy( function binary_cross_entropy (line 120) | def binary_cross_entropy( function l1_loss (line 186) | def l1_loss( function mse_loss (line 211) | def mse_loss( function nll_loss (line 236) | def nll_loss( function gaussian_nll_loss (line 257) | def gaussian_nll_loss( function kl_div_loss (line 312) | def kl_div_loss( function smooth_l1_loss (line 339) | def smooth_l1_loss( function triplet_loss (line 386) | def triplet_loss( function hinge_loss (line 428) | def hinge_loss( function huber_loss (line 453) | def huber_loss( function log_cosh_loss (line 490) | def log_cosh_loss( function cosine_similarity_loss (line 522) | def cosine_similarity_loss( function margin_ranking_loss (line 558) | def margin_ranking_loss( FILE: python/mlx/nn/utils.py function value_and_grad (line 12) | def value_and_grad(model: Module, fn: Callable): function checkpoint (line 41) | def checkpoint(module: Module, fn: Optional[Callable] = None): function _extract_info (line 74) | def _extract_info(flat): function _group_by_size (line 82) | def _group_by_size(keys, sizes, itemsize, communication_size): function average_gradients (line 99) | def average_gradients( function _clip_grads_fsdp (line 176) | def _clip_grads_fsdp(grads_slice, max_norm, group=None): function fsdp_apply_gradients (line 186) | def fsdp_apply_gradients( FILE: python/mlx/optimizers/optimizers.py class Optimizer (line 10) | class Optimizer: method __init__ (line 15) | def __init__(self, schedulers=None): method update (line 20) | def update(self, model: Module, gradients: dict): method init (line 31) | def init(self, parameters: dict): method init_single (line 75) | def init_single(self, parameter: mx.array, state: dict): method apply_gradients (line 85) | def apply_gradients(self, gradients: dict, parameters: dict): method apply_single (line 111) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... method state (line 122) | def state(self): method state (line 127) | def state(self, state: dict): method step (line 132) | def step(self): method learning_rate (line 136) | def learning_rate(self): method learning_rate (line 140) | def learning_rate(self, learning_rate: Union[float, mx.array]): method _maybe_schedule (line 143) | def _maybe_schedule( class MultiOptimizer (line 157) | class MultiOptimizer(Optimizer): method __init__ (line 172) | def __init__(self, optimizers, filters: list = []): method _split_dictionary (line 184) | def _split_dictionary(self, gradients: dict): method init (line 198) | def init(self, parameters: dict): method apply_gradients (line 202) | def apply_gradients(self, gradients: dict, parameters: dict): method state (line 209) | def state(self): method state (line 213) | def state(self, state: dict): method learning_rate (line 221) | def learning_rate(self): method learning_rate (line 225) | def learning_rate(self, learning_rate: Union[float, mx.array]): class SGD (line 230) | class SGD(Optimizer): method __init__ (line 248) | def __init__( method init_single (line 268) | def init_single(self, parameter: mx.array, state: dict): method apply_single (line 272) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... class RMSprop (line 297) | class RMSprop(Optimizer): method __init__ (line 315) | def __init__( method init_single (line 336) | def init_single(self, parameter: mx.array, state: dict): method apply_single (line 340) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... class Adagrad (line 353) | class Adagrad(Optimizer): method __init__ (line 372) | def __init__( method init_single (line 387) | def init_single(self, parameter: mx.array, state: dict): method apply_single (line 391) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... class AdaDelta (line 403) | class AdaDelta(Optimizer): method __init__ (line 425) | def __init__( method init_single (line 445) | def init_single(self, parameter: mx.array, state: dict): method apply_single (line 450) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... class Adam (line 470) | class Adam(Optimizer): method __init__ (line 493) | def __init__( method init_single (line 507) | def init_single(self, parameter: mx.array, state: dict): method apply_single (line 512) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... class AdamW (line 538) | class AdamW(Adam): method __init__ (line 564) | def __init__( method apply_single (line 580) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... class Adamax (line 591) | class Adamax(Adam): method __init__ (line 615) | def __init__( method init_single (line 627) | def init_single(self, parameter: mx.array, state: dict): method apply_single (line 632) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... class Lion (line 650) | class Lion(Optimizer): method __init__ (line 677) | def __init__( method init_single (line 689) | def init_single(self, parameter: mx.array, state: dict): method apply_single (line 693) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... class Adafactor (line 708) | class Adafactor(Optimizer): method __init__ (line 742) | def __init__( method init_single (line 766) | def init_single(self, parameter: mx.array, state: dict): method _compute_rms (line 779) | def _compute_rms(self, inputs): method _compute_learning_rate (line 782) | def _compute_learning_rate(self, step, parameter_rms): method _approximate_exp_moving_avg (line 795) | def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col): method apply_single (line 804) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... class Muon (line 851) | class Muon(Optimizer): method __init__ (line 876) | def __init__( method init_single (line 892) | def init_single(self, parameter: mx.array, state: dict): method _zeropower_via_newtonschulz5 (line 896) | def _zeropower_via_newtonschulz5(self, X, steps: int): method apply_single (line 917) | def apply_single(self, gradient: mx.array, parameter: mx.array, state:... function clip_grad_norm (line 951) | def clip_grad_norm(grads, max_norm): FILE: python/mlx/optimizers/schedulers.py function exponential_decay (line 9) | def exponential_decay(init: float, decay_rate: float) -> Callable: function step_decay (line 34) | def step_decay(init: float, decay_rate: float, step_size: int) -> Callable: function cosine_decay (line 61) | def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Cal... function join_schedules (line 91) | def join_schedules(schedules: List[Callable], boundaries: List[int]) -> ... function linear_schedule (line 131) | def linear_schedule(init: float, end: float, steps: int) -> Callable: FILE: python/mlx/utils.py function tree_map (line 7) | def tree_map( function tree_map_with_path (line 61) | def tree_map_with_path( function tree_flatten (line 117) | def tree_flatten( function tree_unflatten (line 193) | def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -... function tree_reduce (line 243) | def tree_reduce(fn, tree, initializer=None, is_leaf=None): function tree_merge (line 285) | def tree_merge(tree_a, tree_b, merge_fn=None): FILE: python/src/array.cpp class ArrayAt (line 27) | class ArrayAt { method ArrayAt (line 29) | ArrayAt(mx::array x) : x_(std::move(x)) {} method ArrayAt (line 30) | ArrayAt& set_indices(nb::object indices) { method check_initialized (line 35) | void check_initialized() { method add (line 42) | mx::array add(const ScalarOrArray& v) { method subtract (line 46) | mx::array subtract(const ScalarOrArray& v) { method multiply (line 50) | mx::array multiply(const ScalarOrArray& v) { method divide (line 54) | mx::array divide(const ScalarOrArray& v) { method maximum (line 58) | mx::array maximum(const ScalarOrArray& v) { method minimum (line 62) | mx::array minimum(const ScalarOrArray& v) { class ArrayPythonIterator (line 73) | class ArrayPythonIterator { method ArrayPythonIterator (line 75) | ArrayPythonIterator(mx::array x) : idx_(0), x_(std::move(x)) { method next (line 81) | mx::array next() { function init_array (line 99) | void init_array(nb::module_& m) { FILE: python/src/buffer.h type buffer_info (line 59) | struct buffer_info { function getbuffer (line 87) | inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { function releasebuffer (line 122) | inline void releasebuffer(PyObject*, Py_buffer* view) { FILE: python/src/constants.cpp function init_constants (line 8) | void init_constants(nb::module_& m) { FILE: python/src/convert.cpp type PyScalarT (line 10) | enum PyScalarT { type nanobind (line 17) | namespace nanobind { type ndarray_traits (line 19) | struct ndarray_traits { function check_shape_dim (line 28) | int check_shape_dim(int64_t dim) { function nd_array_to_mlx_contiguous (line 37) | mx::array nd_array_to_mlx_contiguous( function nd_array_to_mlx (line 47) | mx::array nd_array_to_mlx( function mlx_to_nd_array_impl (line 110) | nb::ndarray mlx_to_nd_array_impl( function mlx_to_nd_array (line 128) | nb::ndarray mlx_to_nd_array(const mx::array& a) { function mlx_to_np_array (line 163) | nb::ndarray mlx_to_np_array(const mx::array& a) { function mlx_to_dlpack (line 167) | nb::ndarray<> mlx_to_dlpack(const mx::array& a) { function to_scalar (line 171) | nb::object to_scalar(mx::array& a) { function to_list (line 215) | nb::list to_list(mx::array& a, size_t index, int dim) { function tolist (line 229) | nb::object tolist(mx::array& a) { function fill_vector (line 272) | void fill_vector(T list, std::vector& vals) { function PyScalarT (line 285) | PyScalarT validate_shape( function get_shape (line 355) | void get_shape(T list, mx::Shape& shape) { function array_from_list_impl (line 374) | mx::array array_from_list_impl( function array_from_list_impl (line 440) | mx::array array_from_list_impl(T pl, std::optional dtype) { function array_from_list (line 462) | mx::array array_from_list(nb::list pl, std::optional dtype) { function array_from_list (line 466) | mx::array array_from_list(nb::tuple pl, std::optional dtype) { function create_array (line 470) | mx::array create_array(ArrayInitType v, std::optional t) { FILE: python/src/convert.h function namespace (line 15) | namespace nanobind { FILE: python/src/cuda.cpp function init_cuda (line 10) | void init_cuda(nb::module_& m) { FILE: python/src/device.cpp function init_device (line 17) | void init_device(nb::module_& m) { FILE: python/src/distributed.cpp function init_distributed (line 19) | void init_distributed(nb::module_& parent_module) { FILE: python/src/export.cpp function validate_and_extract_inputs (line 23) | std::pair validate_and_extract_inputs( class PyFunctionExporter (line 71) | class PyFunctionExporter { method PyFunctionExporter (line 73) | PyFunctionExporter(mx::FunctionExporter exporter, nb::handle dep) method PyFunctionExporter (line 78) | PyFunctionExporter(const PyFunctionExporter&) = delete; method PyFunctionExporter (line 79) | PyFunctionExporter& operator=(const PyFunctionExporter&) = delete; method PyFunctionExporter (line 80) | PyFunctionExporter& operator=(const PyFunctionExporter&&) = delete; method PyFunctionExporter (line 81) | PyFunctionExporter(PyFunctionExporter&& other) method close (line 84) | void close() { function py_function_exporter_tp_traverse (line 98) | int py_function_exporter_tp_traverse( function wrap_export_function (line 115) | auto wrap_export_function(nb::callable fun) { function init_export (line 134) | void init_export(nb::module_& m) { FILE: python/src/fast.cpp type PyCustomKernelFunction (line 22) | struct PyCustomKernelFunction { method PyCustomKernelFunction (line 23) | PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const ch... function init_fast (line 80) | void init_fast(nb::module_& parent_module) { FILE: python/src/fft.cpp function init_fft (line 17) | void init_fft(nb::module_& parent_module) { FILE: python/src/indexing.cpp function is_none_slice (line 11) | bool is_none_slice(const nb::slice& in_slice) { function is_index_scalar (line 18) | bool is_index_scalar(const nb::object& obj) { function safe_to_int32 (line 35) | int safe_to_int32(nb::object obj) { function get_slice_int (line 48) | int get_slice_int(nb::object obj, int default_val) { function get_slice_params (line 58) | void get_slice_params( function get_int_index (line 77) | mx::array get_int_index(nb::object idx, int axis_size) { function is_valid_index_type (line 84) | bool is_valid_index_type(const nb::object& obj) { function mlx_get_item_slice (line 90) | mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_s... function mlx_get_item_array (line 111) | mx::array mlx_get_item_array(const mx::array& src, const mx::array& indi... function mlx_get_item_int (line 127) | mx::array mlx_get_item_int(const mx::array& src, const nb::object& idx) { function mlx_gather_nd (line 139) | mx::array mlx_gather_nd( function mlx_expand_ellipsis (line 213) | auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entrie... function mlx_get_item_nd (line 274) | mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { function mlx_get_item (line 459) | mx::array mlx_get_item(const mx::array& src, const nb::object& obj) { function mlx_scatter_args_int (line 479) | std::tuple, mx::array, std::vector> function squeeze_leading_singletons (line 504) | mx::array squeeze_leading_singletons(const mx::array& in) { function mlx_scatter_args_array (line 513) | std::tuple, mx::array, std::vector> function mlx_scatter_args_slice (line 535) | std::tuple, mx::array, std::vector> function mlx_scatter_args_nd (line 588) | std::tuple, mx::array, std::vector> function mlx_compute_scatter_args (line 770) | std::tuple, mx::array, std::vector> function mlx_compute_slice_update_args (line 794) | std::tuple, mx::Shape, mx::Shape, mx::Shape> function extract_boolean_mask (line 931) | std::optional extract_boolean_mask(const nb::object& obj) { function mlx_set_item (line 954) | void mlx_set_item( function mlx_add_item (line 982) | mx::array mlx_add_item( function mlx_subtract_item (line 1000) | mx::array mlx_subtract_item( function mlx_multiply_item (line 1018) | mx::array mlx_multiply_item( function mlx_divide_item (line 1036) | mx::array mlx_divide_item( function mlx_maximum_item (line 1054) | mx::array mlx_maximum_item( function mlx_minimum_item (line 1072) | mx::array mlx_minimum_item( FILE: python/src/linalg.cpp function init_linalg (line 18) | void init_linalg(nb::module_& parent_module) { FILE: python/src/load.cpp function is_str_or_path (line 26) | bool is_str_or_path(nb::object obj) { function is_istream_object (line 34) | bool is_istream_object(const nb::object& file) { function is_ostream_object (line 39) | bool is_ostream_object(const nb::object& file) { function is_zip_file (line 44) | bool is_zip_file(const nb::module_& zipfile, const nb::object& file) { class ZipFileWrapper (line 54) | class ZipFileWrapper { method ZipFileWrapper (line 56) | ZipFileWrapper( method namelist (line 72) | std::vector namelist() const { method open (line 76) | nb::object open(const std::string& key, char mode = 'r') { class PyFileReader (line 98) | class PyFileReader : public mx::io::Reader { method PyFileReader (line 100) | PyFileReader(nb::object file) method is_open (line 115) | bool is_open() const override { method good (line 124) | bool good() const override { method tell (line 133) | size_t tell() override { method seek (line 142) | void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) method read (line 148) | void read(char* data, size_t n) override { method read (line 153) | void read(char* data, size_t n, size_t offset) override { method label (line 159) | std::string label() const override { method _read (line 164) | void _read(char* data, size_t n) { function mlx_load_safetensor_helper (line 183) | std::pair< function mlx_load_gguf_helper (line 206) | mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) { function mlx_load_npz_helper (line 215) | std::unordered_map mlx_load_npz_helper( function mlx_load_npy_helper (line 258) | mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) { function LoadOutputTypes (line 275) | LoadOutputTypes mlx_load_helper( class PyFileWriter (line 328) | class PyFileWriter : public mx::io::Writer { method PyFileWriter (line 330) | PyFileWriter(nb::object file) method is_open (line 345) | bool is_open() const override { method good (line 354) | bool good() const override { method tell (line 363) | size_t tell() override { method seek (line 372) | void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) method write (line 378) | void write(const char* data, size_t n) override { method label (line 393) | std::string label() const override { function mlx_save_helper (line 404) | void mlx_save_helper(nb::object file, mx::array a) { function mlx_savez_helper (line 423) | void mlx_savez_helper( function mlx_save_safetensor_helper (line 478) | void mlx_save_safetensor_helper( function mlx_save_gguf_helper (line 513) | void mlx_save_gguf_helper( FILE: python/src/memory.cpp function init_memory (line 10) | void init_memory(nb::module_& m) { FILE: python/src/metal.cpp function DEPRECATE (line 20) | bool DEPRECATE(const char* old_fn, const char* new_fn) { function init_metal (line 28) | void init_metal(nb::module_& m) { FILE: python/src/mlx.cpp function NB_MODULE (line 27) | NB_MODULE(core, m) { FILE: python/src/mlx_func.cpp type gc_func (line 8) | struct gc_func { function gc_func_tp_traverse (line 21) | int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) { function gc_func_tp_clear (line 31) | int gc_func_tp_clear(PyObject* self) { function PyObject (line 37) | PyObject* gc_func_get_doc(PyObject* self, void*) { function PyObject (line 41) | PyObject* gc_func_get_sig(PyObject* self, void*) { function PyObject (line 45) | PyObject* gc_func_vectorcall( function gc_func_dealloc (line 53) | void gc_func_dealloc(PyObject* self) { function PyObject (line 72) | static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) { function mlx_func (line 98) | nb::callable mlx_func( function init_mlx_func (line 111) | void init_mlx_func(nb::module_& m) { FILE: python/src/mlx_func.h function callable (line 19) | callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) { function callable (line 27) | callable FILE: python/src/ops.cpp function scalar_to_dtype (line 28) | mx::Dtype scalar_to_dtype(Scalar s) { function scalar_to_double (line 38) | double scalar_to_double(Scalar s) { function init_ops (line 48) | void init_ops(nb::module_& m) { FILE: python/src/random.cpp class PyKeySequence (line 19) | class PyKeySequence { method PyKeySequence (line 21) | PyKeySequence() { method seed (line 27) | void seed(uint64_t seed) { method next (line 31) | mx::array next() { function PyKeySequence (line 55) | PyKeySequence& default_key() { method PyKeySequence (line 21) | PyKeySequence() { method seed (line 27) | void seed(uint64_t seed) { method next (line 31) | mx::array next() { function init_random (line 61) | void init_random(nb::module_& parent_module) { FILE: python/src/small_vector.h function NAMESPACE_BEGIN (line 9) | NAMESPACE_BEGIN(NB_NAMESPACE) FILE: python/src/stream.cpp class PyStreamContext (line 18) | class PyStreamContext { method PyStreamContext (line 20) | PyStreamContext(mx::StreamOrDevice s) : _inner(nullptr) { method enter (line 28) | void enter() { method exit (line 32) | void exit() { function init_stream (line 44) | void init_stream(nb::module_& m) { FILE: python/src/transforms.cpp function type_name_str (line 36) | inline std::string type_name_str(const nb::handle& o) { function validate_argnums_argnames (line 40) | auto validate_argnums_argnames( function py_value_and_grad (line 69) | auto py_value_and_grad( function py_vmap (line 309) | auto py_vmap( type PyCompiledFun (line 409) | struct PyCompiledFun { type AttachedData (line 418) | struct AttachedData { method AttachedData (line 422) | AttachedData(nb::object output_structure_, int num_outputs_) method PyCompiledFun (line 426) | PyCompiledFun( method PyCompiledFun (line 437) | PyCompiledFun(const PyCompiledFun&) = delete; method PyCompiledFun (line 438) | PyCompiledFun& operator=(const PyCompiledFun&) = delete; method PyCompiledFun (line 439) | PyCompiledFun& operator=(PyCompiledFun&& other) = delete; method PyCompiledFun (line 440) | PyCompiledFun(PyCompiledFun&& other) method call_impl (line 449) | nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { class PyCheckpointedFun (line 590) | class PyCheckpointedFun { method PyCheckpointedFun (line 592) | PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {} type InnerFunction (line 599) | struct InnerFunction { method InnerFunction (line 604) | InnerFunction( method call_impl (line 630) | nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { class PyCustomFunction (line 681) | class PyCustomFunction { method PyCustomFunction (line 683) | PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {} type InnerFunction (line 689) | struct InnerFunction { method InnerFunction (line 694) | InnerFunction( type InnerVJPFunction (line 723) | struct InnerVJPFunction { method InnerVJPFunction (line 728) | InnerVJPFunction( type InnerJVPFunction (line 771) | struct InnerJVPFunction { method InnerJVPFunction (line 775) | InnerJVPFunction(nb::callable jvp_fun, nb::object input_structure) type InnerVmapFunction (line 826) | struct InnerVmapFunction { method InnerVmapFunction (line 830) | InnerVmapFunction(nb::callable vmap_fun, nb::object input_structure) method call_impl (line 896) | nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { method PyCustomFunction (line 925) | PyCustomFunction& set_vjp(nb::callable vjp_fun) { method PyCustomFunction (line 930) | PyCustomFunction& set_jvp(nb::callable jvp_fun) { method PyCustomFunction (line 935) | PyCustomFunction& set_vmap(nb::callable vmap_fun) { method reset (line 939) | void reset() { method make_vjp_function (line 955) | std::optional make_vjp_function( method make_jvp_function (line 965) | std::optional make_jvp_function( method make_vmap_function (line 974) | std::optional make_vmap_function( function py_custom_function_tp_traverse (line 989) | int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void... function py_custom_function_tp_clear (line 1012) | int py_custom_function_tp_clear(PyObject* self) { function init_transforms (line 1022) | void init_transforms(nb::module_& m) { FILE: python/src/trees.cpp function validate_subtrees (line 6) | void validate_subtrees(const std::vector& subtrees) { function tree_map (line 17) | nb::object tree_map( function tree_map (line 86) | nb::object tree_map( function tree_visit (line 94) | void tree_visit( function tree_visit (line 153) | void tree_visit(nb::handle tree, std::function visitor) { function tree_visit_update (line 173) | void tree_visit_update( function tree_fill (line 212) | void tree_fill(nb::object& tree, const std::vector& values) { function tree_replace (line 219) | void tree_replace( function tree_flatten (line 236) | std::vector tree_flatten(nb::handle tree, bool strict /* = tr... function tree_unflatten (line 251) | nb::object tree_unflatten( function structure_sentinel (line 264) | nb::object structure_sentinel() { function tree_flatten_with_structure (line 277) | std::pair, nb::object> tree_flatten_with_structure( function tree_unflatten_from_structure (line 299) | nb::object tree_unflatten_from_structure( FILE: python/src/utils.cpp function to_array (line 8) | mx::array to_array( function to_arrays (line 50) | std::pair to_arrays( function to_array_with_accessor (line 87) | mx::array to_array_with_accessor(nb::object obj) { FILE: python/src/utils.h function std (line 31) | inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { function is_comparable_with_array (line 44) | inline bool is_comparable_with_array(const ScalarOrArray& v) { function nb (line 57) | inline nb::handle get_handle_of_object(const ScalarOrArray& v) { function throw_invalid_operation (line 61) | inline void throw_invalid_operation( FILE: python/tests/mlx_distributed_tests.py class MLXDistributedCommonTestCase (line 10) | class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase): method test_average_gradients (line 11) | def test_average_gradients(self): method test_all_reduce (line 53) | def test_all_reduce(self): method test_donation (line 95) | def test_donation(self): method test_shard_linear (line 113) | def test_shard_linear(self): method test_shard_predicate (line 266) | def test_shard_predicate(self): method test_all_gather (line 309) | def test_all_gather(self): FILE: python/tests/mlx_tests.py class MLXTestRunner (line 19) | class MLXTestRunner(unittest.TestProgram): method __init__ (line 20) | def __init__(self, *args, **kwargs): method createTests (line 23) | def createTests(self, *args, **kwargs): class MLXTestCase (line 55) | class MLXTestCase(unittest.TestCase): method is_apple_silicon (line 57) | def is_apple_silicon(self): method setUp (line 60) | def setUp(self): method tearDown (line 67) | def tearDown(self): method assertCmpNumpy (line 71) | def assertCmpNumpy( method assertEqualArray (line 92) | def assertEqualArray( FILE: python/tests/mpi_test_distributed.py class TestMPIDistributed (line 8) | class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestC... method setUpClass (line 10) | def setUpClass(cls): method test_groups (line 15) | def test_groups(self): method test_all_reduce_extra (line 31) | def test_all_reduce_extra(self): method test_all_gather_extra (line 73) | def test_all_gather_extra(self): method test_mixed (line 93) | def test_mixed(self): method test_send_recv (line 115) | def test_send_recv(self): FILE: python/tests/nccl_test_distributed.py class TestNCCLDistributed (line 10) | class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTest... method setUpClass (line 12) | def setUpClass(cls): method test_sum_scatter (line 17) | def test_sum_scatter(self): method test_groups (line 52) | def test_groups(self): method test_all_reduce_split (line 68) | def test_all_reduce_split(self): method test_all_gather_split (line 109) | def test_all_gather_split(self): method test_fsdp_apply_gradients (line 119) | def test_fsdp_apply_gradients(self): method test_fsdp_ddp_apply_gradients (line 199) | def test_fsdp_ddp_apply_gradients(self): method test_fsdp_peak_memory (line 300) | def test_fsdp_peak_memory(self): FILE: python/tests/ring_test_distributed.py class TestRingDistributed (line 8) | class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTest... method setUpClass (line 10) | def setUpClass(cls): method test_groups (line 15) | def test_groups(self): method test_all_reduce_extra (line 27) | def test_all_reduce_extra(self): method test_all_gather_extra (line 67) | def test_all_gather_extra(self): method test_send_recv (line 80) | def test_send_recv(self): method test_all_gather_vjp (line 119) | def test_all_gather_vjp(self): FILE: python/tests/test_array.py class TestVersion (line 27) | class TestVersion(mlx_tests.MLXTestCase): method test_version (line 28) | def test_version(self): class TestDtypes (line 36) | class TestDtypes(mlx_tests.MLXTestCase): method test_dtypes (line 37) | def test_dtypes(self): method test_scalar_conversion (line 66) | def test_scalar_conversion(self): method test_finfo (line 96) | def test_finfo(self): method test_iinfo (line 110) | def test_iinfo(self): class TestEquality (line 123) | class TestEquality(mlx_tests.MLXTestCase): method test_array_eq_array (line 124) | def test_array_eq_array(self): method test_array_eq_scalar (line 131) | def test_array_eq_scalar(self): method test_list_equals_array (line 142) | def test_list_equals_array(self): method test_tuple_equals_array (line 152) | def test_tuple_equals_array(self): class TestInequality (line 163) | class TestInequality(mlx_tests.MLXTestCase): method test_array_ne_array (line 164) | def test_array_ne_array(self): method test_array_ne_scalar (line 171) | def test_array_ne_scalar(self): method test_list_not_equals_array (line 184) | def test_list_not_equals_array(self): method test_dlx_device_type (line 194) | def test_dlx_device_type(self): method test_tuple_not_equals_array (line 207) | def test_tuple_not_equals_array(self): method test_obj_inequality_array (line 217) | def test_obj_inequality_array(self): method test_invalid_op_on_array (line 250) | def test_invalid_op_on_array(self): class TestArray (line 274) | class TestArray(mlx_tests.MLXTestCase): method test_array_basics (line 275) | def test_array_basics(self): method test_bool_conversion (line 355) | def test_bool_conversion(self): method test_int_type (line 365) | def test_int_type(self): method test_construction_from_lists (line 377) | def test_construction_from_lists(self): method test_double_keeps_precision (line 433) | def test_double_keeps_precision(self): method test_construction_from_lists_of_mlx_arrays (line 441) | def test_construction_from_lists_of_mlx_arrays(self): method test_init_from_array (line 501) | def test_init_from_array(self): method test_array_repr (line 519) | def test_array_repr(self): method test_array_to_list (line 600) | def test_array_to_list(self): method test_array_np_conversion (line 648) | def test_array_np_conversion(self): method test_array_np_dtype_conversion (line 709) | def test_array_np_dtype_conversion(self): method test_array_from_noncontiguous_np (line 741) | def test_array_from_noncontiguous_np(self): method test_array_np_shape_dim_check (line 748) | def test_array_np_shape_dim_check(self): method test_dtype_promotion (line 756) | def test_dtype_promotion(self): method test_dtype_python_scalar_promotion (line 797) | def test_dtype_python_scalar_promotion(self): method test_array_comparison (line 839) | def test_array_comparison(self): method test_array_neg (line 854) | def test_array_neg(self): method test_array_type_cast (line 859) | def test_array_type_cast(self): method test_array_iteration (line 869) | def test_array_iteration(self): method test_array_pickle (line 881) | def test_array_pickle(self): method test_array_copy (line 903) | def test_array_copy(self): method test_indexing (line 928) | def test_indexing(self): method test_indexing_grad (line 1104) | def test_indexing_grad(self): method test_setitem (line 1117) | def test_setitem(self): method test_array_at (line 1374) | def test_array_at(self): method test_array_at_slice_update_extensive (line 1452) | def test_array_at_slice_update_extensive(self): method test_slice_negative_step (line 1530) | def test_slice_negative_step(self): method test_api (line 1600) | def test_api(self): method test_memoryless_copy (line 1664) | def test_memoryless_copy(self): method test_np_array_conversion_copies_by_default (line 1683) | def test_np_array_conversion_copies_by_default(self): method test_buffer_protocol (line 1689) | def test_buffer_protocol(self): method test_buffer_protocol_ref_counting (line 1765) | def test_buffer_protocol_ref_counting(self): method test_array_view_ref_counting (line 1775) | def test_array_view_ref_counting(self): method test_buffer_protocol_tf (line 1786) | def test_buffer_protocol_tf(self): method test_logical_overloads (line 1848) | def test_logical_overloads(self): method test_inplace (line 1861) | def test_inplace(self): method test_inplace_preserves_ids (line 1913) | def test_inplace_preserves_ids(self): method test_load_from_pickled_np (line 1928) | def test_load_from_pickled_np(self): method test_multi_output_leak (line 1937) | def test_multi_output_leak(self): method test_add_numpy (line 1965) | def test_add_numpy(self): method test_dlpack (line 1972) | def test_dlpack(self): method test_getitem_with_list (line 1986) | def test_getitem_with_list(self): method test_setitem_with_list (line 2007) | def test_setitem_with_list(self): method test_setitem_with_boolean_mask (line 2048) | def test_setitem_with_boolean_mask(self): method test_array_namespace (line 2079) | def test_array_namespace(self): method test_array_namespace_asarray (line 2085) | def test_array_namespace_asarray(self): method test_asarray (line 2099) | def test_asarray(self): method test_to_scalar (line 2137) | def test_to_scalar(self): method test_format (line 2152) | def test_format(self): method test_deep_graphs (line 2165) | def test_deep_graphs(self): method test_siblings_without_eval (line 2193) | def test_siblings_without_eval(self): method test_scalar_integer_conversion_overflow (line 2215) | def test_scalar_integer_conversion_overflow(self): method test_real_imag (line 2223) | def test_real_imag(self): method test_large_indices (line 2232) | def test_large_indices(self): FILE: python/tests/test_autograd.py class TestAutograd (line 10) | class TestAutograd(mlx_tests.MLXTestCase): method test_jvp (line 11) | def test_jvp(self): method test_jvp_comparison_tangent_dtype (line 33) | def test_jvp_comparison_tangent_dtype(self): method test_vjp (line 51) | def test_vjp(self): method test_grad (line 73) | def test_grad(self): method test_grad_trees (line 124) | def test_grad_trees(self): method test_auxiliary_values (line 164) | def test_auxiliary_values(self): method test_grad_kwargs (line 184) | def test_grad_kwargs(self): method test_captured (line 235) | def test_captured(self): method test_stop_gradient (line 259) | def test_stop_gradient(self): method test_update_state (line 283) | def test_update_state(self): method test_scatter_vjp (line 298) | def test_scatter_vjp(self): method test_scatter_add_vjp (line 317) | def test_scatter_add_vjp(self): method test_scatter_max_vjp (line 330) | def test_scatter_max_vjp(self): method test_scatter_min_vjp (line 350) | def test_scatter_min_vjp(self): method test_slice_update_max_vjp (line 370) | def test_slice_update_max_vjp(self): method test_slice_update_min_vjp (line 390) | def test_slice_update_min_vjp(self): method test_slice_update_add_vjp (line 410) | def test_slice_update_add_vjp(self): method test_slice_update_multiply_vjp (line 423) | def test_slice_update_multiply_vjp(self): method test_split_against_slice (line 436) | def test_split_against_slice(self): method test_vjp_types (line 455) | def test_vjp_types(self): method test_power_grad (line 477) | def test_power_grad(self): method test_eval_in_grad (line 490) | def test_eval_in_grad(self): method test_power_grad (line 512) | def test_power_grad(self): method test_cumprod_grad (line 520) | def test_cumprod_grad(self): method test_topk_grad (line 593) | def test_topk_grad(self): method test_custom_function (line 603) | def test_custom_function(self): method test_complex_vjps (line 687) | def test_complex_vjps(self): method test_flatten_unflatten_vjps (line 702) | def test_flatten_unflatten_vjps(self): method test_concatenate_vjps (line 717) | def test_concatenate_vjps(self): method test_matmul_jvps (line 729) | def test_matmul_jvps(self): method test_put_along_axis_grads (line 764) | def test_put_along_axis_grads(self): method test_slice_grads (line 795) | def test_slice_grads(self): method test_leaks (line 832) | def test_leaks(self): method test_grad_with_copies (line 858) | def test_grad_with_copies(self): method test_grad_ids_pre_post (line 869) | def test_grad_ids_pre_post(self): method test_grad_with_inplace_update (line 890) | def test_grad_with_inplace_update(self): method test_autograd_types (line 904) | def test_autograd_types(self): method test_reduce_jvp (line 953) | def test_reduce_jvp(self): FILE: python/tests/test_bf16.py class TestBF16 (line 19) | class TestBF16(mlx_tests.MLXTestCase): method __test_ops (line 20) | def __test_ops( method __default_test (line 37) | def __default_test( method test_unary_ops (line 100) | def test_unary_ops(self): method test_binary_ops (line 107) | def test_binary_ops(self): method test_reduction_ops (line 120) | def test_reduction_ops(self): method test_arg_reduction_ops (line 137) | def test_arg_reduction_ops(self): method test_blas_ops (line 160) | def test_blas_ops(self): method test_conversion (line 187) | def test_conversion(self): FILE: python/tests/test_blas.py class TestBlas (line 12) | class TestBlas(mlx_tests.MLXTestCase): method dtypes (line 14) | def dtypes(self): method __gemm_test (line 17) | def __gemm_test( method test_matmul_unaligned (line 49) | def test_matmul_unaligned(self): method test_matvec_unaligned (line 64) | def test_matvec_unaligned(self): method test_matmul_shapes (line 71) | def test_matmul_shapes(self): method test_matmul (line 131) | def test_matmul(self): method test_matmul_dtypes (line 160) | def test_matmul_dtypes(self): method test_matmul_batched (line 176) | def test_matmul_batched(self): method __gemv_test (line 262) | def __gemv_test( method test_matrix_vector (line 308) | def test_matrix_vector(self): method test_matrix_vector_batched (line 349) | def test_matrix_vector_batched(self): method test_matrix_vector_broadcast (line 376) | def test_matrix_vector_broadcast(self): method test_matrix_vector_attn (line 414) | def test_matrix_vector_attn(self): method test_matrix_vector_edgecases (line 485) | def test_matrix_vector_edgecases(self): method test_mismatch_stride_mm (line 538) | def test_mismatch_stride_mm(self): method test_addmm (line 602) | def test_addmm(self): method test_addmm_grad (line 733) | def test_addmm_grad(self): method test_empty_matmul (line 771) | def test_empty_matmul(self): method test_block_masked_matmul (line 838) | def test_block_masked_matmul(self): method test_gather_matmul (line 1048) | def test_gather_matmul(self): method test_gather_matmul_grad (line 1193) | def test_gather_matmul_grad(self): method test_gather_mm_sorted (line 1239) | def test_gather_mm_sorted(self): method test_gather_mm_sorted_vjp (line 1267) | def test_gather_mm_sorted_vjp(self): method test_segmented_mm (line 1294) | def test_segmented_mm(self): method test_gemv_gemm_same_precision (line 1356) | def test_gemv_gemm_same_precision(self): method test_complex_gemv (line 1368) | def test_complex_gemv(self): method test_complex_gemm (line 1398) | def test_complex_gemm(self): FILE: python/tests/test_compile.py class TestCompile (line 15) | class TestCompile(mlx_tests.MLXTestCase): method test_simple_compile (line 16) | def test_simple_compile(self): method test_compile_grad (line 47) | def test_compile_grad(self): method test_compile_inputs_with_primitives (line 80) | def test_compile_inputs_with_primitives(self): method test_compile_with_closure (line 105) | def test_compile_with_closure(self): method test_function_creates_array (line 160) | def test_function_creates_array(self): method test_enable_disable (line 172) | def test_enable_disable(self): method test_compile_two_input_grad (line 198) | def test_compile_two_input_grad(self): method test_vmap_compiled (line 210) | def test_vmap_compiled(self): method test_vjp_vjp_compiled (line 252) | def test_vjp_vjp_compiled(self): method test_transform_over_eval_compiled (line 288) | def test_transform_over_eval_compiled(self): method test_compile_capture (line 309) | def test_compile_capture(self): method test_compile_rng (line 381) | def test_compile_rng(self): method test_compile_kwargs (line 388) | def test_compile_kwargs(self): method test_shapeless_compile (line 399) | def test_shapeless_compile(self): method test_shapeless_compile_with_broadcasts (line 423) | def test_shapeless_compile_with_broadcasts(self): method test_shapeless_compile_with_reduction (line 437) | def test_shapeless_compile_with_reduction(self): method test_shapeless_compile_unflatten (line 469) | def test_shapeless_compile_unflatten(self): method test_shapeless_compile_gather (line 477) | def test_shapeless_compile_gather(self): method test_shapeless_compile_full_like (line 485) | def test_shapeless_compile_full_like(self): method test_compile_with_constant (line 507) | def test_compile_with_constant(self): method test_compile_inf (line 630) | def test_compile_inf(self): method test_unsupported_input_types (line 638) | def test_unsupported_input_types(self): method test_compile_create_list (line 652) | def test_compile_create_list(self): method test_compile_vjp (line 660) | def test_compile_vjp(self): method test_shapeless_mean (line 700) | def test_shapeless_mean(self): method test_compile_broadcast_only (line 722) | def test_compile_broadcast_only(self): method test_compile_with_long_name (line 732) | def test_compile_with_long_name(self): method test_compile_multi_output (line 742) | def test_compile_multi_output(self): method test_inf_constant (line 755) | def test_inf_constant(self): method test_max_into_equal (line 762) | def test_max_into_equal(self): method test_dtypes (line 774) | def test_dtypes(self): method test_compile_without_captured_inputs (line 788) | def test_compile_without_captured_inputs(self): method test_compile_dynamic_dims (line 813) | def test_compile_dynamic_dims(self): method test_compile_many_inputs (line 826) | def test_compile_many_inputs(self): method test_compile_many_outputs (line 866) | def test_compile_many_outputs(self): method test_shapeless_compile_matmul (line 880) | def test_shapeless_compile_matmul(self): method test_shapeless_compile_slice_update (line 887) | def test_shapeless_compile_slice_update(self): method test_shapeless_compile_with_reshape (line 900) | def test_shapeless_compile_with_reshape(self): method test_compile_shapeless_with_broadcast (line 919) | def test_compile_shapeless_with_broadcast(self): method test_leaks (line 992) | def test_leaks(self): method test_double_constant (line 1019) | def test_double_constant(self): method test_shared_broadcast (line 1030) | def test_shared_broadcast(self): method test_compile_large_graph_with_broadcasts (line 1052) | def test_compile_large_graph_with_broadcasts(self): method test_wrap_compiled (line 1074) | def test_wrap_compiled(self): method test_compiled_preserves_attributes (line 1083) | def test_compiled_preserves_attributes(self): method test_compile_with_none (line 1096) | def test_compile_with_none(self): method test_compile_changing_outputs (line 1110) | def test_compile_changing_outputs(self): method test_compile_changing_outputs_with_state (line 1142) | def test_compile_changing_outputs_with_state(self): method test_outputs_changing (line 1161) | def test_outputs_changing(self): method test_multiple_compile_same_capture (line 1180) | def test_multiple_compile_same_capture(self): method test_compile_types (line 1204) | def test_compile_types(self): method test_compile_output_with_siblings (line 1248) | def test_compile_output_with_siblings(self): method test_compile_donates_input_buffer (line 1277) | def test_compile_donates_input_buffer(self): FILE: python/tests/test_constants.py class TestConstants (line 10) | class TestConstants(mlx_tests.MLXTestCase): method test_constants_values (line 11) | def test_constants_values(self): method test_constants_availability (line 24) | def test_constants_availability(self): method test_newaxis_for_reshaping_arrays (line 33) | def test_newaxis_for_reshaping_arrays(self): FILE: python/tests/test_conv.py class TestConv (line 20) | class TestConv(mlx_tests.MLXTestCase): method test_numpy_conv (line 21) | def test_numpy_conv(self): method test_conv_1d_groups_flipped (line 50) | def test_conv_1d_groups_flipped(self): method test_torch_conv_1D (line 58) | def test_torch_conv_1D(self): method test_torch_conv_1D_grad (line 167) | def test_torch_conv_1D_grad(self): method test_torch_conv_2D (line 273) | def test_torch_conv_2D(self): method test_torch_conv_2D_grad (line 373) | def test_torch_conv_2D_grad(self): method test_torch_conv_3D (line 483) | def test_torch_conv_3D(self): method test_torch_conv_3D_grad (line 567) | def test_torch_conv_3D_grad(self): method __conv_general_test (line 685) | def __conv_general_test( method test_torch_conv_general (line 793) | def test_torch_conv_general(self): method test_conv_general_flip_grad (line 884) | def test_conv_general_flip_grad(self): method test_conv_groups_grad (line 953) | def test_conv_groups_grad(self): method test_repeated_conv (line 1046) | def test_repeated_conv(self): method test_torch_conv_depthwise (line 1055) | def test_torch_conv_depthwise(self): method test_asymmetric_padding (line 1091) | def test_asymmetric_padding(self): method test_basic_grad_shapes (line 1132) | def test_basic_grad_shapes(self): method test_conv_1d_with_2d (line 1154) | def test_conv_1d_with_2d(self): method test_conv2d_unaligned_channels (line 1175) | def test_conv2d_unaligned_channels(self): method test_conv2d_large_filter_small_channels (line 1188) | def test_conv2d_large_filter_small_channels(self): FILE: python/tests/test_conv_transpose.py class TestConvTranspose (line 20) | class TestConvTranspose(mlx_tests.MLXTestCase): method test_torch_conv_transpose_1D (line 22) | def test_torch_conv_transpose_1D(self): method test_torch_conv_transpose_1D_grad (line 131) | def test_torch_conv_transpose_1D_grad(self): method test_torch_conv_transpose_2D (line 229) | def test_torch_conv_transpose_2D(self): method test_torch_conv_transpose_2D_grad (line 319) | def test_torch_conv_transpose_2D_grad(self): method test_torch_conv_transpose_3D (line 416) | def test_torch_conv_transpose_3D(self): method test_torch_conv_transpose_3D_grad (line 494) | def test_torch_conv_transpose_3D_grad(self): method test_torch_conv_tranpose_1d_output_padding (line 600) | def test_torch_conv_tranpose_1d_output_padding(self): method test_torch_conv_transpose_2d_output_padding (line 656) | def test_torch_conv_transpose_2d_output_padding(self): method test_torch_conv_transpose_3d_output_padding (line 731) | def test_torch_conv_transpose_3d_output_padding(self): FILE: python/tests/test_device.py class TestDefaultDevice (line 10) | class TestDefaultDevice(unittest.TestCase): method test_mlx_default_device (line 11) | def test_mlx_default_device(self): class TestDevice (line 24) | class TestDevice(mlx_tests.MLXTestCase): method test_device (line 25) | def test_device(self): method test_device_context (line 42) | def test_device_context(self): method test_op_on_device (line 52) | def test_op_on_device(self): class TestStream (line 67) | class TestStream(mlx_tests.MLXTestCase): method test_stream (line 68) | def test_stream(self): method test_op_on_stream (line 96) | def test_op_on_stream(self): class TestDeviceInfo (line 116) | class TestDeviceInfo(mlx_tests.MLXTestCase): method test_device_count (line 117) | def test_device_count(self): method test_device_info_cpu (line 126) | def test_device_info_cpu(self): method test_device_info_gpu (line 134) | def test_device_info_gpu(self): method test_device_info_default (line 143) | def test_device_info_default(self): FILE: python/tests/test_double.py class TestDouble (line 12) | class TestDouble(mlx_tests.MLXTestCase): method test_unary_ops (line 13) | def test_unary_ops(self): method test_binary_ops (line 59) | def test_binary_ops(self): method test_where (line 98) | def test_where(self): method test_reductions (line 115) | def test_reductions(self): method test_get_and_set_item (line 133) | def test_get_and_set_item(self): method test_gemm (line 158) | def test_gemm(self): method test_type_promotion (line 176) | def test_type_promotion(self): method test_lapack (line 186) | def test_lapack(self): method test_conversion (line 290) | def test_conversion(self): method test_linspace (line 299) | def test_linspace(self): FILE: python/tests/test_einsum.py class TestEinsum (line 10) | class TestEinsum(mlx_tests.MLXTestCase): method test_simple_path (line 12) | def test_simple_path(self): method test_longer_paths (line 43) | def test_longer_paths(self): method test_simple_einsum (line 68) | def test_simple_einsum(self): method test_two_input_einsum (line 136) | def test_two_input_einsum(self): method test_sum_first (line 180) | def test_sum_first(self): method test_broadcasting (line 187) | def test_broadcasting(self): method test_attention (line 201) | def test_attention(self): method test_multi_input_einsum (line 214) | def test_multi_input_einsum(self): method test_opt_einsum_test_cases (line 220) | def test_opt_einsum_test_cases(self): method test_ellipses (line 316) | def test_ellipses(self): FILE: python/tests/test_eval.py class TestEval (line 10) | class TestEval(mlx_tests.MLXTestCase): method test_eval (line 11) | def test_eval(self): method test_retain_graph (line 17) | def test_retain_graph(self): method test_eval_mixed (line 27) | def test_eval_mixed(self): method test_async_eval (line 35) | def test_async_eval(self): method test_async_eval_twice (line 53) | def test_async_eval_twice(self): method test_async_eval_in_trace (line 62) | def test_async_eval_in_trace(self): method test_async_eval_into_eval (line 76) | def test_async_eval_into_eval(self): method test_async_eval_into_eval_diff_stream (line 84) | def test_async_eval_into_eval_diff_stream(self): method test_eval_slow_fast_multi_stream (line 92) | def test_eval_slow_fast_multi_stream(self): method test_multi_output_eval_during_transform (line 108) | def test_multi_output_eval_during_transform(self): method test_async_eval_with_multiple_streams (line 124) | def test_async_eval_with_multiple_streams(self): method test_donation_for_noops (line 139) | def test_donation_for_noops(self): method test_multistream_deadlock (line 176) | def test_multistream_deadlock(self): FILE: python/tests/test_export_import.py class TestExportImport (line 13) | class TestExportImport(mlx_tests.MLXTestCase): method setUpClass (line 16) | def setUpClass(cls): method tearDownClass (line 23) | def tearDownClass(cls): method test_basic_export_import (line 26) | def test_basic_export_import(self): method test_export_random_sample (line 134) | def test_export_random_sample(self): method test_export_with_kwargs (line 152) | def test_export_with_kwargs(self): method test_export_variable_inputs (line 203) | def test_export_variable_inputs(self): method test_leaks (line 244) | def test_leaks(self): method test_export_import_shapeless (line 272) | def test_export_import_shapeless(self): method test_export_scatter_gather (line 290) | def test_export_scatter_gather(self): method test_export_conv (line 316) | def test_export_conv(self): method test_export_conv_shapeless (line 349) | def test_export_conv_shapeless(self): method test_export_control_flow (line 448) | def test_export_control_flow(self): method test_export_quantized_model (line 466) | def test_export_quantized_model(self): method test_export_kwarg_ordering (line 488) | def test_export_kwarg_ordering(self): method test_export_with_callback (line 501) | def test_export_with_callback(self): method test_export_import_custom_kernel (line 541) | def test_export_import_custom_kernel(self): method test_export_import_multi_with_constants (line 584) | def test_export_import_multi_with_constants(self): method test_export_import_scatter_sum (line 605) | def test_export_import_scatter_sum(self): FILE: python/tests/test_fast.py function rope_orig (line 10) | def rope_orig(x, dims, traditional, base, scale, offset, freqs=None): function rms_norm (line 51) | def rms_norm(x, weight, eps): function layer_norm (line 57) | def layer_norm(x, weight, bias, eps): class TestFast (line 71) | class TestFast(mlx_tests.MLXTestCase): method test_rope (line 72) | def test_rope(self): method test_rope_dims_validation (line 172) | def test_rope_dims_validation(self): method test_rope_with_freqs (line 207) | def test_rope_with_freqs(self): method test_rope_grad (line 292) | def test_rope_grad(self): method test_rope_batch (line 319) | def test_rope_batch(self): method test_rope_with_large_offset (line 368) | def test_rope_with_large_offset(self): method test_rms_norm (line 388) | def test_rms_norm(self): method test_rms_norm_grad (line 448) | def test_rms_norm_grad(self): method test_layer_norm_dim_check (line 491) | def test_layer_norm_dim_check(self): method test_layer_norm (line 502) | def test_layer_norm(self): method test_slice_into_layer_norm (line 588) | def test_slice_into_layer_norm(self): method test_layer_norm_grad (line 596) | def test_layer_norm_grad(self): method test_layer_norm_grad_no_bias (line 639) | def test_layer_norm_grad_no_bias(self): method test_layer_norm_grad_no_params (line 675) | def test_layer_norm_grad_no_params(self): method test_layer_norm_grad_params (line 686) | def test_layer_norm_grad_params(self): method test_fast_transforms (line 701) | def test_fast_transforms(self): method test_custom_kernel_basic (line 740) | def test_custom_kernel_basic(self): method test_custom_kernel_args (line 773) | def test_custom_kernel_args(self): method test_custom_kernel_strides (line 832) | def test_custom_kernel_strides(self): method test_custom_kernel_helper (line 886) | def test_custom_kernel_helper(self): method test_custom_kernel_attributes (line 932) | def test_custom_kernel_attributes(self): method test_custom_kernel_caching (line 958) | def test_custom_kernel_caching(self): FILE: python/tests/test_fast_sdpa.py function mlx_ref_attn (line 10) | def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None): function do_attention (line 66) | def do_attention(f, q, k, v, scale, mask=None, transpose=False): function prepare_inputs (line 77) | def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype): function mlx_primitives_sdpa (line 98) | def mlx_primitives_sdpa(q, k, v, scale, mask=None): class TestFastSDPA (line 118) | class TestFastSDPA(mlx_tests.MLXTestCase): method test_sdpa_vector_kv_transposed_head_seq (line 119) | def test_sdpa_vector_kv_transposed_head_seq(self): method test_sdpa_vector (line 151) | def test_sdpa_vector(self): method test_sdpa_fully_masked (line 223) | def test_sdpa_fully_masked(self): method test_sdpa_inf_score (line 235) | def test_sdpa_inf_score(self): method test_sdpa_few_query (line 247) | def test_sdpa_few_query(self): method test_sdpa_vector_value_dims (line 306) | def test_sdpa_vector_value_dims(self): method test_sdpa_vector_batched (line 322) | def test_sdpa_vector_batched(self): method test_sdpa (line 363) | def test_sdpa(self): method test_sdpa_broadcast_mask (line 445) | def test_sdpa_broadcast_mask(self): method test_sdpa_noncontiguous_inputs (line 461) | def test_sdpa_noncontiguous_inputs(self): method test_sdpa_promote_mask (line 472) | def test_sdpa_promote_mask(self): method test_sdpa_nan_bug (line 488) | def test_sdpa_nan_bug(self): method test_sdpa_attention_sinks (line 516) | def test_sdpa_attention_sinks(self): method test_sdpa_grad (line 566) | def test_sdpa_grad(self): method test_sdpa_sliced (line 613) | def test_sdpa_sliced(self): FILE: python/tests/test_fft.py class TestFFT (line 18) | class TestFFT(mlx_tests.MLXTestCase): method check_mx_np (line 19) | def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwar... method test_fft (line 25) | def test_fft(self): method test_fftn (line 62) | def test_fftn(self): method _run_ffts (line 105) | def _run_ffts(self, shape, atol=1e-4, rtol=1e-4): method test_fft_shared_mem (line 122) | def test_fft_shared_mem(self): method test_fft_exhaustive (line 145) | def test_fft_exhaustive(self): method test_fft_big_powers_of_two (line 153) | def test_fft_big_powers_of_two(self): method test_fft_large_numbers (line 161) | def test_fft_large_numbers(self): method test_fft_contiguity (line 174) | def test_fft_contiguity(self): method test_fft_into_ifft (line 203) | def test_fft_into_ifft(self): method test_fft_throws (line 215) | def test_fft_throws(self): method test_fftshift (line 220) | def test_fftshift(self): method test_ifftshift (line 245) | def test_ifftshift(self): method test_fftshift_errors (line 270) | def test_fftshift_errors(self): method test_fft_grads (line 283) | def test_fft_grads(self): FILE: python/tests/test_graph.py class TestGraph (line 10) | class TestGraph(mlx_tests.MLXTestCase): method test_to_dot (line 11) | def test_to_dot(self): FILE: python/tests/test_init.py class TestInit (line 10) | class TestInit(mlx_tests.MLXTestCase): method test_constant (line 11) | def test_constant(self): method test_normal (line 22) | def test_normal(self): method test_uniform (line 33) | def test_uniform(self): method test_identity (line 46) | def test_identity(self): method test_glorot_normal (line 56) | def test_glorot_normal(self): method test_glorot_uniform (line 65) | def test_glorot_uniform(self): method test_he_normal (line 74) | def test_he_normal(self): method test_he_uniform (line 83) | def test_he_uniform(self): method test_sparse (line 92) | def test_sparse(self): method test_orthogonal (line 109) | def test_orthogonal(self): FILE: python/tests/test_linalg.py class TestLinalg (line 12) | class TestLinalg(mlx_tests.MLXTestCase): method test_norm (line 13) | def test_norm(self): method test_complex_norm (line 68) | def test_complex_norm(self): method test_qr_factorization (line 95) | def test_qr_factorization(self): method test_svd_decomposition (line 137) | def test_svd_decomposition(self): method test_inverse (line 207) | def test_inverse(self): method test_tri_inverse (line 221) | def test_tri_inverse(self): method test_cholesky (line 242) | def test_cholesky(self): method test_pseudo_inverse (line 259) | def test_pseudo_inverse(self): method test_cholesky_inv (line 276) | def test_cholesky_inv(self): method test_cross_product (line 303) | def test_cross_product(self): method test_eig (line 351) | def test_eig(self): method test_eigh (line 435) | def test_eigh(self): method test_lu (line 493) | def test_lu(self): method test_lu_factor (line 523) | def test_lu_factor(self): method test_solve (line 540) | def test_solve(self): method test_solve_triangular (line 594) | def test_solve_triangular(self): FILE: python/tests/test_load.py class TestLoad (line 14) | class TestLoad(mlx_tests.MLXTestCase): method setUpClass (line 30) | def setUpClass(cls): method tearDownClass (line 37) | def tearDownClass(cls): method test_save_and_load (line 40) | def test_save_and_load(self): method test_load_npy_dtype (line 75) | def test_load_npy_dtype(self): method test_save_and_load_safetensors (line 90) | def test_save_and_load_safetensors(self): method test_save_and_load_gguf (line 131) | def test_save_and_load_gguf(self): method test_load_f8_e4m3 (line 167) | def test_load_f8_e4m3(self): method test_save_and_load_gguf_metadata_basic (line 192) | def test_save_and_load_gguf_metadata_basic(self): method test_save_and_load_gguf_metadata_arrays (line 225) | def test_save_and_load_gguf_metadata_arrays(self): method test_save_and_load_gguf_metadata_mixed (line 261) | def test_save_and_load_gguf_metadata_mixed(self): method test_save_and_load_fs (line 303) | def test_save_and_load_fs(self): method test_savez_and_loadz (line 341) | def test_savez_and_loadz(self): method test_non_contiguous (line 391) | def test_non_contiguous(self): method test_load_donation (line 425) | def test_load_donation(self): FILE: python/tests/test_losses.py class TestLosses (line 11) | class TestLosses(mlx_tests.MLXTestCase): method test_cross_entropy (line 12) | def test_cross_entropy(self): method test_binary_cross_entropy (line 79) | def test_binary_cross_entropy(self): method test_l1_loss (line 167) | def test_l1_loss(self): method test_mse_loss (line 188) | def test_mse_loss(self): method test_smooth_l1_loss (line 217) | def test_smooth_l1_loss(self): method test_nll_loss (line 254) | def test_nll_loss(self): method test_gaussian_nll_loss (line 273) | def test_gaussian_nll_loss(self): method test_kl_div_loss (line 318) | def test_kl_div_loss(self): method test_triplet_loss (line 337) | def test_triplet_loss(self): method test_hinge_loss (line 363) | def test_hinge_loss(self): method test_huber_loss (line 369) | def test_huber_loss(self): method test_log_cosh_loss (line 375) | def test_log_cosh_loss(self): method test_cosine_similarity_loss (line 381) | def test_cosine_similarity_loss(self): method test_margin_ranking_loss (line 406) | def test_margin_ranking_loss(self): FILE: python/tests/test_memory.py class TestMemory (line 9) | class TestMemory(mlx_tests.MLXTestCase): method test_memory_info (line 10) | def test_memory_info(self): method test_wired_memory (line 52) | def test_wired_memory(self): method test_active_memory_count (line 61) | def test_active_memory_count(self): FILE: python/tests/test_nn.py class TestBase (line 14) | class TestBase(mlx_tests.MLXTestCase): method test_module_utilities (line 15) | def test_module_utilities(self): method test_module_attributes (line 57) | def test_module_attributes(self): method test_model_with_dict (line 76) | def test_model_with_dict(self): method test_save_npz_weights (line 88) | def test_save_npz_weights(self): method test_save_safetensors_weights (line 106) | def test_save_safetensors_weights(self): method test_load_from_weights (line 124) | def test_load_from_weights(self): method test_module_state (line 190) | def test_module_state(self): method test_chaining (line 195) | def test_chaining(self): method test_quantize (line 205) | def test_quantize(self): method test_quantize_freeze (line 245) | def test_quantize_freeze(self): method test_quantized_sharded_linear_construction (line 252) | def test_quantized_sharded_linear_construction(self): method test_grad_of_module (line 264) | def test_grad_of_module(self): method test_update (line 278) | def test_update(self): method test_update_modules (line 295) | def test_update_modules(self): method test_parameter_deletion (line 333) | def test_parameter_deletion(self): method test_circular_leaks (line 338) | def test_circular_leaks(self): class TestLayers (line 356) | class TestLayers(mlx_tests.MLXTestCase): method test_identity (line 357) | def test_identity(self): method test_linear (line 363) | def test_linear(self): method test_bilinear (line 369) | def test_bilinear(self): method test_group_norm (line 376) | def test_group_norm(self): method test_instance_norm (line 412) | def test_instance_norm(self): method test_batch_norm (line 630) | def test_batch_norm(self): method test_batch_norm_stats (line 747) | def test_batch_norm_stats(self): method test_conv1d (line 788) | def test_conv1d(self): method test_conv2d (line 824) | def test_conv2d(self): method test_sequential (line 878) | def test_sequential(self): method test_gelu (line 894) | def test_gelu(self): method test_sin_pe (line 924) | def test_sin_pe(self): method test_sigmoid (line 935) | def test_sigmoid(self): method test_relu (line 944) | def test_relu(self): method test_leaky_relu (line 951) | def test_leaky_relu(self): method test_elu (line 963) | def test_elu(self): method test_relu6 (line 979) | def test_relu6(self): method test_softmax (line 986) | def test_softmax(self): method test_softmin (line 995) | def test_softmin(self): method test_softplus (line 1004) | def test_softplus(self): method test_softsign (line 1013) | def test_softsign(self): method test_softshrink (line 1022) | def test_softshrink(self): method test_celu (line 1037) | def test_celu(self): method test_log_softmax (line 1052) | def test_log_softmax(self): method test_log_sigmoid (line 1061) | def test_log_sigmoid(self): method test_prelu (line 1070) | def test_prelu(self): method test_mish (line 1076) | def test_mish(self): method test_hardswish (line 1082) | def test_hardswish(self): method test_glu (line 1091) | def test_glu(self): method test_hard_tanh (line 1097) | def test_hard_tanh(self): method test_hard_shrink (line 1105) | def test_hard_shrink(self): method test_rope (line 1119) | def test_rope(self): method test_alibi (line 1134) | def test_alibi(self): method test_dropout (line 1145) | def test_dropout(self): method test_dropout2d (line 1161) | def test_dropout2d(self): method test_dropout3d (line 1177) | def test_dropout3d(self): method test_upsample (line 1193) | def test_upsample(self): method test_pooling (line 1415) | def test_pooling(self): method test_set_dtype (line 1877) | def test_set_dtype(self): method test_rnn (line 1900) | def test_rnn(self): method test_gru (line 1927) | def test_gru(self): method test_lstm (line 1952) | def test_lstm(self): method test_quantized_embedding (line 1974) | def test_quantized_embedding(self): method test_causal_mask (line 1994) | def test_causal_mask(self): method test_attention (line 2003) | def test_attention(self): FILE: python/tests/test_ops.py function np_wrap_between (line 13) | def np_wrap_between(x, a): function np_logaddexp (line 25) | def np_logaddexp(x1: np.ndarray, x2: np.ndarray): function np_cumlogaddexp (line 47) | def np_cumlogaddexp(x1: np.ndarray, axis: int = -1): class TestOps (line 54) | class TestOps(mlx_tests.MLXTestCase): method test_full_ones_zeros (line 55) | def test_full_ones_zeros(self): method test_scalar_inputs (line 95) | def test_scalar_inputs(self): method test_add (line 181) | def test_add(self): method test_subtract (line 234) | def test_subtract(self): method test_multiply (line 250) | def test_multiply(self): method test_divide (line 266) | def test_divide(self): method test_remainder (line 300) | def test_remainder(self): method test_comparisons (line 336) | def test_comparisons(self): method test_array_equal (line 356) | def test_array_equal(self): method test_isnan (line 383) | def test_isnan(self): method test_isinf (line 398) | def test_isinf(self): method test_isfinite (line 421) | def test_isfinite(self): method test_tri (line 431) | def test_tri(self): method test_tril (line 438) | def test_tril(self): method test_triu (line 445) | def test_triu(self): method test_minimum (line 451) | def test_minimum(self): method test_maximum (line 463) | def test_maximum(self): method test_floor (line 475) | def test_floor(self): method test_ceil (line 483) | def test_ceil(self): method test_isposinf (line 491) | def test_isposinf(self): method test_isneginf (line 514) | def test_isneginf(self): method test_round (line 537) | def test_round(self): method test_transpose_noargs (line 575) | def test_transpose_noargs(self): method test_transpose_axis (line 586) | def test_transpose_axis(self): method test_move_swap_axes (line 600) | def test_move_swap_axes(self): method test_sum (line 607) | def test_sum(self): method test_prod (line 659) | def test_prod(self): method test_min_and_max (line 674) | def test_min_and_max(self): method test_argmin_argmax (line 696) | def test_argmin_argmax(self): method test_broadcast (line 714) | def test_broadcast(self): method test_logsumexp (line 733) | def test_logsumexp(self): method test_mean (line 763) | def test_mean(self): method test_median (line 778) | def test_median(self): method test_var (line 811) | def test_var(self): method test_std (line 834) | def test_std(self): method test_abs (line 839) | def test_abs(self): method test_negative (line 847) | def test_negative(self): method test_sign (line 853) | def test_sign(self): method test_logical_not (line 868) | def test_logical_not(self): method test_logical_and (line 874) | def test_logical_and(self): method test_logical_or (line 885) | def test_logical_or(self): method test_square (line 896) | def test_square(self): method test_sqrt (line 903) | def test_sqrt(self): method test_rsqrt (line 909) | def test_rsqrt(self): method test_reciprocal (line 915) | def test_reciprocal(self): method test_logaddexp (line 921) | def test_logaddexp(self): method test_log (line 944) | def test_log(self): method test_log2 (line 956) | def test_log2(self): method test_log10 (line 968) | def test_log10(self): method test_exp (line 980) | def test_exp(self): method test_expm1 (line 987) | def test_expm1(self): method test_erf (line 995) | def test_erf(self): method test_erfinv (line 1001) | def test_erfinv(self): method test_sin (line 1023) | def test_sin(self): method test_cos (line 1032) | def test_cos(self): method test_degrees (line 1041) | def test_degrees(self): method test_radians (line 1050) | def test_radians(self): method test_log1p (line 1057) | def test_log1p(self): method test_sigmoid (line 1071) | def test_sigmoid(self): method test_allclose (line 1083) | def test_allclose(self): method test_isclose (line 1098) | def test_isclose(self): method test_all (line 1110) | def test_all(self): method test_any (line 1121) | def test_any(self): method test_stop_gradient (line 1132) | def test_stop_gradient(self): method test_kron (line 1141) | def test_kron(self): method test_take (line 1169) | def test_take(self): method test_take_along_axis (line 1253) | def test_take_along_axis(self): method test_put_along_axis (line 1269) | def test_put_along_axis(self): method test_split (line 1310) | def test_split(self): method test_split_invalid_num_splits (line 1331) | def test_split_invalid_num_splits(self): method test_arange_overload_dispatch (line 1350) | def test_arange_overload_dispatch(self): method test_arange_inferred_dtype (line 1390) | def test_arange_inferred_dtype(self): method test_arange_corner_cases_cast (line 1418) | def test_arange_corner_cases_cast(self): method test_hanning_general (line 1472) | def test_hanning_general(self): method test_hamming_general (line 1484) | def test_hamming_general(self): method test_bartlett_general (line 1496) | def test_bartlett_general(self): method test_blackman_general (line 1508) | def test_blackman_general(self): method test_unary_ops (line 1520) | def test_unary_ops(self): method test_unary_ops_from_non_array (line 1538) | def test_unary_ops_from_non_array(self): method test_trig_ops (line 1587) | def test_trig_ops(self): method test_binary_ops (line 1682) | def test_binary_ops(self): method test_irregular_binary_ops (line 1751) | def test_irregular_binary_ops(self): method test_softmax (line 1800) | def test_softmax(self): method test_concatenate (line 1850) | def test_concatenate(self): method test_meshgrid (line 1874) | def test_meshgrid(self): method test_pad (line 1937) | def test_pad(self): method test_where (line 1980) | def test_where(self): method test_nan_to_num (line 2001) | def test_nan_to_num(self): method test_as_strided (line 2018) | def test_as_strided(self): method test_logcumsumexp (line 2036) | def test_logcumsumexp(self): method test_scans (line 2068) | def test_scans(self): method test_squeeze_expand (line 2177) | def test_squeeze_expand(self): method test_sort (line 2193) | def test_sort(self): method test_partition (line 2312) | def test_partition(self): method test_argpartition (line 2343) | def test_argpartition(self): method test_large_binary (line 2358) | def test_large_binary(self): method test_eye (line 2363) | def test_eye(self): method test_stack (line 2372) | def test_stack(self): method test_flatten (line 2403) | def test_flatten(self): method test_clip (line 2412) | def test_clip(self): method test_linspace (line 2456) | def test_linspace(self): method test_repeat (line 2490) | def test_repeat(self): method test_tensordot (line 2511) | def test_tensordot(self): method test_inner (line 2541) | def test_inner(self): method test_outer (line 2546) | def test_outer(self): method test_divmod (line 2569) | def test_divmod(self): method test_tile (line 2591) | def test_tile(self): method test_empty_matmuls (line 2608) | def test_empty_matmuls(self): method test_diagonal (line 2618) | def test_diagonal(self): method test_diag (line 2632) | def test_diag(self): method test_trace (line 2674) | def test_trace(self): method test_atleast_1d (line 2706) | def test_atleast_1d(self): method test_atleast_2d (line 2728) | def test_atleast_2d(self): method test_atleast_3d (line 2750) | def test_atleast_3d(self): method test_issubdtype (line 2772) | def test_issubdtype(self): method test_bitwise_ops (line 2806) | def test_bitwise_ops(self): method test_bitwise_grad (line 2857) | def test_bitwise_grad(self): method test_conjugate (line 2872) | def test_conjugate(self): method test_view (line 2885) | def test_view(self): method _hadamard (line 2913) | def _hadamard(self, N): method test_hadamard (line 2920) | def test_hadamard(self): method test_hadamard_grad_vmap (line 3015) | def test_hadamard_grad_vmap(self): method test_roll (line 3043) | def test_roll(self): method test_roll_errors (line 3077) | def test_roll_errors(self): method test_real_imag (line 3082) | def test_real_imag(self): method test_dynamic_slicing (line 3097) | def test_dynamic_slicing(self): method test_broadcast_arrays (line 3110) | def test_broadcast_arrays(self): method test_slice_update_reversed (line 3123) | def test_slice_update_reversed(self): method test_slice_with_negative_stride (line 3129) | def test_slice_with_negative_stride(self): method test_complex_ops (line 3139) | def test_complex_ops(self): method test_complex_power (line 3168) | def test_complex_power(self): method test_irregular_alignments (line 3175) | def test_irregular_alignments(self): method test_integer_power (line 3194) | def test_integer_power(self): method test_depends (line 3201) | def test_depends(self): method test_masked_scatter (line 3214) | def test_masked_scatter(self): method test_broadcast_shapes (line 3275) | def test_broadcast_shapes(self): method test_sort_nan (line 3313) | def test_sort_nan(self): method test_argsort_nan (line 3322) | def test_argsort_nan(self): method test_to_from_fp8 (line 3331) | def test_to_from_fp8(self): FILE: python/tests/test_optimizers.py function get_all_optimizers (line 25) | def get_all_optimizers(): function tree_equal (line 37) | def tree_equal(fn, *args): class TestOptimizers (line 45) | class TestOptimizers(mlx_tests.MLXTestCase): method test_optimizer_state (line 46) | def test_optimizer_state(self): method test_optimizers (line 54) | def test_optimizers(self): method test_types_conserved (line 69) | def test_types_conserved(self): method test_sgd (line 77) | def test_sgd(self): method test_rmsprop (line 102) | def test_rmsprop(self): method test_adagrad (line 130) | def test_adagrad(self): method test_adadelta (line 148) | def test_adadelta(self): method test_adam (line 173) | def test_adam(self): method test_adamw_matches_pytorch (line 207) | def test_adamw_matches_pytorch(self): method test_lion (line 251) | def test_lion(self): method test_adafactor (line 269) | def test_adafactor(self): method test_muon (line 289) | def test_muon(self): method test_compiled_optimizer (line 336) | def test_compiled_optimizer(self): method test_update_lr_compiled (line 394) | def test_update_lr_compiled(self): class TestSchedulers (line 410) | class TestSchedulers(mlx_tests.MLXTestCase): method test_decay_lr (line 411) | def test_decay_lr(self): method test_step_decay (line 424) | def test_step_decay(self): method test_exponential_decay (line 430) | def test_exponential_decay(self): method test_cosine_decay (line 436) | def test_cosine_decay(self): method test_schedule_joiner (line 449) | def test_schedule_joiner(self): method test_linear_warmup_with_cosine_decay (line 463) | def test_linear_warmup_with_cosine_decay(self): method test_compile_with_schedule (line 480) | def test_compile_with_schedule(self): method test_clip_grad_norm (line 492) | def test_clip_grad_norm(self): method test_init_from_state (line 534) | def test_init_from_state(self): method test_multi_optimizer (line 558) | def test_multi_optimizer(self): FILE: python/tests/test_quantized.py class TestQuantized (line 10) | class TestQuantized(mlx_tests.MLXTestCase): method test_quantize_dequantize (line 11) | def test_quantize_dequantize(self): method test_mxfp4_quantize_dequantize (line 30) | def test_mxfp4_quantize_dequantize(self): method test_mxfp8_quantize_dequantize (line 85) | def test_mxfp8_quantize_dequantize(self): method test_nvfp4_quantize_dequantize (line 113) | def test_nvfp4_quantize_dequantize(self): method test_qqmv (line 176) | def test_qqmv(self): method test_qmm (line 208) | def test_qmm(self): method test_qmm_vjp (line 245) | def test_qmm_vjp(self): method test_qmm_jvp (line 275) | def test_qmm_jvp(self): method test_qmm_shapes (line 305) | def test_qmm_shapes(self): method test_qmv (line 330) | def test_qmv(self): method test_fp_qmv (line 357) | def test_fp_qmv(self): method test_qvm (line 408) | def test_qvm(self): method test_qvm_splitk (line 435) | def test_qvm_splitk(self): method test_fp_qvm (line 473) | def test_fp_qvm(self): method test_mode_error_cases (line 506) | def test_mode_error_cases(self): method test_throw (line 571) | def test_throw(self): method test_small_matrix (line 587) | def test_small_matrix(self): method test_non_multiples (line 623) | def test_non_multiples(self): method test_qmv_small_non_multiples (line 697) | def test_qmv_small_non_multiples(self): method test_gather_qmm (line 756) | def test_gather_qmm(self): method test_qmm_fp_type (line 897) | def test_qmm_fp_type(self): method test_gather_matmul_grad (line 916) | def test_gather_matmul_grad(self): method test_gather_qmm_sorted (line 945) | def test_gather_qmm_sorted(self): method test_gather_qmm_grad (line 1059) | def test_gather_qmm_grad(self): method test_vjp_scales_biases (line 1107) | def test_vjp_scales_biases(self): method test_fp_vjp_scales_throws (line 1132) | def test_fp_vjp_scales_throws(self): method test_quantize_strided (line 1160) | def test_quantize_strided(self): FILE: python/tests/test_random.py class TestRandom (line 10) | class TestRandom(mlx_tests.MLXTestCase): method test_global_rng (line 11) | def test_global_rng(self): method test_key (line 23) | def test_key(self): method test_key_split (line 31) | def test_key_split(self): method test_uniform (line 44) | def test_uniform(self): method test_normal_and_laplace (line 67) | def test_normal_and_laplace(self): method test_multivariate_normal (line 112) | def test_multivariate_normal(self): method test_randint (line 202) | def test_randint(self): method test_bernoulli (line 237) | def test_bernoulli(self): method test_truncated_normal (line 260) | def test_truncated_normal(self): method test_gumbel (line 292) | def test_gumbel(self): method test_categorical (line 305) | def test_categorical(self): method test_permutation (line 328) | def test_permutation(self): method test_complex_normal (line 355) | def test_complex_normal(self): method test_broadcastable_scale_loc (line 374) | def test_broadcastable_scale_loc(self): FILE: python/tests/test_reduce.py class TestReduce (line 11) | class TestReduce(mlx_tests.MLXTestCase): method test_axis_permutation_sums (line 12) | def test_axis_permutation_sums(self): method test_expand_sums (line 29) | def test_expand_sums(self): method test_dtypes (line 50) | def test_dtypes(self): method test_arg_reduce (line 86) | def test_arg_reduce(self): method test_edge_case (line 118) | def test_edge_case(self): method test_sum_bool (line 127) | def test_sum_bool(self): method test_many_reduction_axes (line 134) | def test_many_reduction_axes(self): method test_nan_propagation (line 156) | def test_nan_propagation(self): method test_nan_propagation_complex64 (line 182) | def test_nan_propagation_complex64(self): method test_long_column (line 213) | def test_long_column(self): FILE: python/tests/test_tree.py class TestTreeUtils (line 11) | class TestTreeUtils(mlx_tests.MLXTestCase): method test_tree_map (line 12) | def test_tree_map(self): method test_tree_flatten (line 19) | def test_tree_flatten(self): method test_merge (line 26) | def test_merge(self): method test_supported_trees (line 49) | def test_supported_trees(self): FILE: python/tests/test_upsample.py class TestUpsample (line 19) | class TestUpsample(mlx_tests.MLXTestCase): method test_torch_upsample (line 21) | def test_torch_upsample(self): FILE: python/tests/test_vmap.py class TestVmap (line 10) | class TestVmap(mlx_tests.MLXTestCase): method test_basics (line 11) | def test_basics(self): method test_unary (line 33) | def test_unary(self): method test_binary (line 70) | def test_binary(self): method test_tree (line 111) | def test_tree(self): method test_vmap_indexing (line 167) | def test_vmap_indexing(self): method test_vmap_reduce (line 222) | def test_vmap_reduce(self): method test_vmap_argreduce (line 245) | def test_vmap_argreduce(self): method test_vmap_mean (line 255) | def test_vmap_mean(self): method test_mismatch_input_sizes (line 266) | def test_mismatch_input_sizes(self): method test_vmap_matmul (line 277) | def test_vmap_matmul(self): method test_vmap_svd (line 316) | def test_vmap_svd(self): method test_vmap_inverse (line 370) | def test_vmap_inverse(self): method test_vmap_gather (line 400) | def test_vmap_gather(self): method test_vmap_scatter (line 445) | def test_vmap_scatter(self): method test_vmap_const_func (line 537) | def test_vmap_const_func(self): method test_vmap_concatenate (line 557) | def test_vmap_concatenate(self): method test_vmap_take_along_axis (line 577) | def test_vmap_take_along_axis(self): method test_vmap_put_along_axis (line 598) | def test_vmap_put_along_axis(self): method test_vmap_split_vmap (line 624) | def test_vmap_split_vmap(self): method test_leaks (line 636) | def test_leaks(self): method test_vmap_flatten (line 665) | def test_vmap_flatten(self): method test_vmap_conv (line 675) | def test_vmap_conv(self): method test_vmap_types (line 726) | def test_vmap_types(self): method test_vmap_masked_scatter (line 771) | def test_vmap_masked_scatter(self): FILE: setup.py function cuda_toolkit_major_version (line 16) | def cuda_toolkit_major_version(): function get_version (line 25) | def get_version(): class CMakeExtension (line 63) | class CMakeExtension(Extension): method __init__ (line 64) | def __init__(self, name: str, sourcedir: str = "") -> None: class CMakeBuild (line 69) | class CMakeBuild(build_ext): method build_extension (line 70) | def build_extension(self, ext: CMakeExtension) -> None: method run (line 151) | def run(self): class MLXBdistWheel (line 175) | class MLXBdistWheel(bdist_wheel): method get_tag (line 176) | def get_tag(self) -> tuple[str, str, str]: FILE: tests/arg_reduce_tests.cpp function test_arg_reduce_small (line 10) | void test_arg_reduce_small( function test_arg_reduce_against_cpu (line 27) | void test_arg_reduce_against_cpu( function test_arg_reduce_small_bool (line 125) | void test_arg_reduce_small_bool( FILE: tests/compile_tests.cpp function simple_fun (line 13) | std::vector simple_fun(const std::vector& inputs) { function grad_fun (line 40) | std::vector grad_fun(const std::vector& inputs) { function fun_creats_array (line 70) | std::vector fun_creats_array(const std::vector& inputs) { function inner_fun (line 84) | std::vector inner_fun(const std::vector& inputs) { function outer_fun (line 88) | std::vector outer_fun(const std::vector& inputs) { function add_scalars (line 112) | auto add_scalars(const std::vector&) { function max_scalars (line 118) | auto max_scalars(const std::vector&) { function exp_two (line 147) | auto exp_two(const std::vector& inputs) { function add_diff (line 171) | auto add_diff(const std::vector& inputs) { function multi_one (line 184) | auto multi_one(const std::vector&) { function multi_two (line 194) | auto multi_two(const std::vector&) { function multi_three (line 200) | auto multi_three(const std::vector&) { function unary_fused_0 (line 239) | auto unary_fused_0(const std::vector& inputs) { function unary_fused_1 (line 244) | auto unary_fused_1(const std::vector& inputs) { function unary_fused_1_copy (line 248) | auto unary_fused_1_copy(const std::vector& inputs) { function unary_fused_1_diff (line 252) | auto unary_fused_1_diff(const std::vector& inputs) { function unary_fused_2 (line 257) | auto unary_fused_2(const std::vector& inputs) { function unary_fused_3 (line 262) | auto unary_fused_3(const std::vector& inputs) { function binary_fused_0 (line 333) | auto binary_fused_0(const std::vector& inputs) { function binary_fused_1 (line 338) | auto binary_fused_1(const std::vector& inputs) { function binary_fused_2 (line 343) | auto binary_fused_2(const std::vector& inputs) { function binary_fused_3 (line 349) | auto binary_fused_3(const std::vector& inputs) { function gelu_1 (line 409) | auto gelu_1(const std::vector& inputs) { function unary_with_two_outputs (line 447) | auto unary_with_two_outputs(const std::vector& inputs) { function uncompilable_inputs (line 452) | auto uncompilable_inputs(const std::vector& inputs) { function uncompilable_inputs_order_matters (line 458) | auto uncompilable_inputs_order_matters(const std::vector& inputs) { function compile_across_streams (line 511) | auto compile_across_streams(const std::vector& inputs) { function unary_compile_outputs (line 531) | auto unary_compile_outputs(const std::vector& inputs) { function binary_compile_outputs (line 537) | auto binary_compile_outputs(const std::vector& inputs) { function deep_unary_compile (line 575) | auto deep_unary_compile(const std::vector& inputs) { function repeat_input_to_compiled (line 591) | auto repeat_input_to_compiled(const std::vector& inputs) { function compile_unary_inner (line 605) | auto compile_unary_inner(const std::vector& inputs) { function compile_unary_outer (line 610) | auto compile_unary_outer(const std::vector& inputs) { function grad_unary_compiled (line 624) | auto grad_unary_compiled(const std::vector& inputs) { function add3 (line 662) | auto add3(const std::vector& xs) { function compile_shapeless_not_ok (line 685) | auto compile_shapeless_not_ok(const std::vector& inputs) { function compile_shapeless_ok (line 690) | auto compile_shapeless_ok(const std::vector& inputs) { function compile_broadcast_add (line 722) | auto compile_broadcast_add(const std::vector& inputs) { FILE: tests/einsum_tests.cpp type std (line 8) | namespace std { function ostream (line 11) | ostream& operator<<(ostream& os, const vector>&) { FILE: tests/export_import_tests.cpp function get_temp_file (line 15) | std::string get_temp_file(const std::string& name) { FILE: tests/load_tests.cpp function get_temp_file (line 13) | std::string get_temp_file(const std::string& name) { FILE: tests/test_teardown.cpp function main (line 15) | int main() { FILE: tests/tests.cpp function main (line 12) | int main(int argc, char** argv) { FILE: tests/utils_tests.cpp type TestCase (line 30) | struct TestCase {